{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE InstanceSigs #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE PolyKinds #-}
-- {-# LANGUAGE OverloadedRecordDot #-}
{-# LANGUAGE QuasiQuotes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilyDependencies #-}
{-# OPTIONS_GHC -Wno-redundant-constraints #-}

-- |
-- Module      : Language.Halide.Func
-- Description : Functions / Arrays
-- Copyright   : (c) Tom Westerhout, 2023
module Language.Halide.Func
  ( -- * Defining pipelines
    Func (..)
  , FuncTy (..)
  , Stage (..)
  , Function
  , Parameter
  , buffer
  , scalar
  , define
  , (!)
  , realizeOnTarget
  , realize

    -- * Scheduling
  , Schedulable (..)
  , TailStrategy (..)

    -- ** 'Func'-specific
  , computeRoot
  , getStage
  , getLoopLevel
  , getLoopLevelAtStage
  , asUsed
  , asUsedBy
  , copyToDevice
  , copyToHost
  , storeAt
  , computeAt
  , dim
  , estimate
  , bound
  , getArgs
  -- , deepCopy

    -- * Update definitions
  , update
  , hasUpdateDefinitions
  , getUpdateStage

    -- * Debugging
  , prettyLoopNest

    -- * Internal
  , asBufferParam
  , withFunc
  , withCxxFunc
  , withBufferParam
  , wrapCxxFunc
  , CxxStage
  , wrapCxxStage
  , withCxxStage
  )
where

import Control.Exception (bracket)
import Control.Monad (forM)
import Data.Constraint
import Data.Functor ((<&>))
import Data.IORef
import Data.Kind (Type)
import Data.Proxy
import Data.Text (Text)
import Data.Text.Encoding qualified as T
import Foreign.ForeignPtr
import Foreign.Marshal (toBool, with)
import Foreign.Ptr (Ptr, castPtr)
import GHC.Stack (HasCallStack)
import GHC.TypeLits
import Language.C.Inline qualified as C
import Language.C.Inline.Cpp.Exception qualified as C
import Language.C.Inline.Unsafe qualified as CU
import Language.Halide.Buffer
import Language.Halide.Context
import Language.Halide.Dimension
import Language.Halide.Expr
import Language.Halide.LoopLevel
import Language.Halide.Target
import Language.Halide.Type
import Language.Halide.Utils
import System.IO.Unsafe (unsafePerformIO)
import Unsafe.Coerce
import Prelude hiding (min, tail)

-- | Haskell counterpart of [Halide::Stage](https://halide-lang.org/docs/class_halide_1_1_stage.html).
--
-- This type has no constructors: it is only ever used behind a pointer
-- (@Ptr CxxStage@ / @ForeignPtr CxxStage@) that refers to a C++ object.
data CxxStage

-- Template Haskell splice; presumably it sets up the inline-C++ context used by the
-- @C.throwBlock@ quasi-quotes below (NOTE(review): confirm in its defining module).
-- TH splices split the module into declaration groups, so keep 'CxxStage' above this line
-- and do not move declarations across it without checking.
importHalide

-- | A function in Halide. Conceptually, it can be thought of as a lazy
-- @n@-dimensional buffer of type @a@.
--
-- Here, @a@ is most often @'Expr' t@ for a type @t@ that is an instance of 'IsHalideType'.
-- However, one can also define @Func@s that return multiple values. In this case, @a@ will
-- be a tuple of 'Expr's.
--
-- This is a wrapper around the [@Halide::Func@](https://halide-lang.org/docs/class_halide_1_1_func.html)
-- C++ type. The @t@ index tells user-defined functions ('FuncTy') apart from pipeline
-- parameters ('ParamTy').
data Func (t :: FuncTy) (n :: Nat) (a :: Type) where
  -- | A function defined by the user, owning a C++ @Halide::Func@ object.
  Func
    :: {-# UNPACK #-} !(ForeignPtr CxxFunc)
    -> Func 'FuncTy n a
  -- | A buffer parameter of the pipeline. The underlying C++ image parameter is kept in a
  -- mutable reference and may be absent ('Nothing'); NOTE(review): presumably it is created
  -- on first use — confirm against 'asBufferParam' / 'withBufferParam'.
  Param
    :: IsHalideType a
    => {-# UNPACK #-} !(IORef (Maybe (ForeignPtr CxxImageParam)))
    -> Func 'ParamTy n (Expr a)

-- | Function type. It can either be 'FuncTy' which means that we have defined the function ourselves,
-- or 'ParamTy' which means that it's a parameter to our pipeline.
data FuncTy = FuncTy | ParamTy
  deriving stock (Show, Eq, Ord)

-- | Synonym for the most commonly used function type: a user-defined ('FuncTy')
-- @n@-dimensional t'Func' whose elements are @'Expr' a@.
type Function (n :: Nat) a = Func 'FuncTy n (Expr a)

-- | Synonym for the most commonly used parameter type: an @n@-dimensional pipeline
-- parameter ('ParamTy') whose elements are @'Expr' a@.
type Parameter (n :: Nat) a = Func 'ParamTy n (Expr a)

-- | A single definition of a t'Func'.
--
-- Wraps a C++ @Halide::Stage@ (see 'CxxStage'). The parameters @n@ and @a@ are phantom:
-- they mirror the dimensionality and element type of the t'Func' the stage came from.
newtype Stage (n :: Nat) (a :: Type)
  = Stage (ForeignPtr CxxStage)

-- | Different ways to handle a tail case in a split when the split factor does
-- not provably divide the extent.
--
-- This is the Haskell counterpart of [@Halide::TailStrategy@](https://halide-lang.org/docs/namespace_halide.html#a6c6557df562bd7850664e70fdb8fea0f).
--
-- NOTE: 'split' passes a strategy to C++ as @static_cast<Halide::TailStrategy>(fromEnum strategy)@,
-- so 'fromEnum' must agree with the numbering of the C++ enum. Keep the constructor order in sync
-- with Halide if the 'Enum' instance relies on it.
data TailStrategy
  = -- | Round up the extent to be a multiple of the split factor.
    --
    -- Not legal for RVars, as it would change the meaning of the algorithm.
    --
    -- * Pros: generates the simplest, fastest code.
    -- * Cons: if used on a stage that reads from the input or writes to the
    --   output, constrains the input or output size to be a multiple of the
    --   split factor.
    TailRoundUp
  | -- | Guard the inner loop with an if statement that prevents evaluation
    -- beyond the original extent.
    --
    -- Always legal. The if statement is treated like a boundary condition, and
    -- factored out into a loop epilogue if possible.
    --
    -- * Pros: no redundant re-evaluation; does not constrain input or output sizes.
    -- * Cons: increases code size due to separate tail-case handling;
    --   vectorization will scalarize in the tail case to handle the if
    --   statement.
    TailGuardWithIf
  | -- | Guard the loads and stores in the loop with an if statement that
    -- prevents evaluation beyond the original extent.
    --
    -- Always legal. The if statement is treated like a boundary condition, and
    -- factored out into a loop epilogue if possible.
    --
    -- * Pros: no redundant re-evaluation; does not constrain input or output
    --   sizes.
    -- * Cons: increases code size due to separate tail-case handling.
    TailPredicate
  | -- | Guard the loads in the loop with an if statement that prevents
    -- evaluation beyond the original extent.
    --
    -- Only legal for innermost splits. Not legal for RVars, as it would change
    -- the meaning of the algorithm. The if statement is treated like a
    -- boundary condition, and factored out into a loop epilogue if possible.
    --
    -- * Pros: does not constrain input sizes, output size constraints are
    --   simpler than full predication.
    -- * Cons: increases code size due to separate tail-case handling,
    --   constrains the output size to be a multiple of the split factor.
    TailPredicateLoads
  | -- | Guard the stores in the loop with an if statement that prevents
    -- evaluation beyond the original extent.
    --
    -- Only legal for innermost splits. Not legal for RVars, as it would change
    -- the meaning of the algorithm. The if statement is treated like a
    -- boundary condition, and factored out into a loop epilogue if possible.
    --
    -- * Pros: does not constrain output sizes, input size constraints are
    --   simpler than full predication.
    -- * Cons: increases code size due to separate tail-case handling,
    --   constrains the input size to be a multiple of the split factor.
    TailPredicateStores
  | -- | Prevent evaluation beyond the original extent by shifting the tail
    -- case inwards, re-evaluating some points near the end.
    --
    -- Only legal for pure variables in pure definitions. If the inner loop is
    -- very simple, the tail case is treated like a boundary condition and
    -- factored out into an epilogue.
    --
    -- This is a good trade-off between several factors. Like 'TailRoundUp', it
    -- supports vectorization well, because the inner loop is always a fixed
    -- size with no data-dependent branching. It increases code size slightly
    -- for inner loops due to the epilogue handling, but not for outer loops
    -- (e.g. loops over tiles). If used on a stage that reads from an input or
    -- writes to an output, this strategy only requires that the input/output
    -- extent be at least the split factor, instead of a multiple of the split
    -- factor as with 'TailRoundUp'.
    TailShiftInwards
  | -- | For pure definitions use 'TailShiftInwards'.
    --
    -- For pure vars in update definitions use 'TailRoundUp'. For RVars in update
    -- definitions use 'TailGuardWithIf'.
    TailAuto
  deriving stock (Eq, Ord, Show)

-- | Common scheduling functions.
--
-- Instances are provided for 'Stage' and for t'Func'. The directives that return
-- @IO (f n a)@ hand back the function or stage they were given, so calls can be chained.
class KnownNat n => Schedulable f (n :: Nat) (a :: Type) where
  -- | Vectorize the dimension.
  vectorize :: VarOrRVar -> f n a -> IO (f n a)

  -- | Unroll the dimension.
  unroll :: VarOrRVar -> f n a -> IO (f n a)

  -- | Reorder variables to have the given nesting order, from innermost out.
  --
  -- Note that @variables@ should only contain variables that belong to the function.
  -- If this is not the case, a runtime error will be thrown.
  reorder
    :: [VarOrRVar]
    -- ^ variables
    -> f n a
    -- ^ function or stage
    -> IO (f n a)

  -- | Split a dimension into inner and outer subdimensions with the given names, where the inner dimension
  -- iterates from @0@ to @factor-1@.
  --
  -- The inner and outer subdimensions can then be dealt with using the other scheduling calls. It's okay
  -- to reuse the old variable name as either the inner or outer variable. The first argument specifies
  -- how the tail should be handled if the split factor does not provably divide the extent.
  split
    :: TailStrategy
    -- ^ how to handle the tail case
    -> VarOrRVar
    -- ^ dimension to split
    -> (VarOrRVar, VarOrRVar)
    -- ^ @(outer, inner)@ names of the new subdimensions
    -> Expr Int32
    -- ^ split factor
    -> f n a
    -> IO (f n a)

  -- | Join two dimensions into a single fused dimension.
  --
  -- The fused dimension covers the product of the extents of the inner and outer dimensions given.
  -- The pair is forwarded in the argument order of @Halide::Stage::fuse@, i.e. as
  -- @(inner, outer)@: its first component becomes the innermost part of the fused dimension.
  fuse
    :: (VarOrRVar, VarOrRVar)
    -- ^ @(inner, outer)@ dimensions to fuse
    -> VarOrRVar
    -- ^ name of the fused dimension
    -> f n a
    -> IO (f n a)

  -- | Mark the dimension to be traversed serially
  serial :: VarOrRVar -> f n a -> IO (f n a)

  -- | Mark the dimension to be traversed in parallel
  parallel :: VarOrRVar -> f n a -> IO (f n a)

  -- | Issue atomic updates for this Func.
  atomic
    :: Bool
    -- ^ whether to override the associativity test
    -> f n a
    -> IO (f n a)

  -- | Specialize the definition on a runtime condition.
  --
  -- Returns the 'Stage' holding the specialized definition. For more info, see
  -- [Halide::Stage::specialize](https://halide-lang.org/docs/class_halide_1_1_stage.html).
  specialize :: Expr Bool -> f n a -> IO (Stage n a)

  -- | Add a specialization that aborts at runtime with the given error message.
  --
  -- This is a thin wrapper around @Halide::Stage::specialize_fail@.
  specializeFail :: Text -> f n a -> IO ()

  -- | Map the given @k@ variables (@1 <= k <= 3@) onto GPU block indices for the given device API.
  gpuBlocks :: (KnownNat k, 1 <= k, k <= 3) => DeviceAPI -> IndexType k -> f n a -> IO (f n a)

  -- | Map the given @k@ variables (@1 <= k <= 3@) onto GPU thread indices for the given device API.
  gpuThreads :: (KnownNat k, 1 <= k, k <= 3) => DeviceAPI -> IndexType k -> f n a -> IO (f n a)

  -- | Map the dimension onto GPU lanes for the given device API
  -- (a thin wrapper around @Halide::Stage::gpu_lanes@).
  gpuLanes :: DeviceAPI -> VarOrRVar -> f n a -> IO (f n a)

  -- | Schedule the iteration over this stage to be fused with another stage from outermost loop to a
  -- given LoopLevel.
  --
  -- For more info, see [Halide::Stage::compute_with](https://halide-lang.org/docs/class_halide_1_1_stage.html#a82a2ae25a009d6a2d52cb407a25f0a5b).
  computeWith :: LoopAlignStrategy -> f n a -> LoopLevel t -> IO ()

-- | GHC is not able to automatically prove the transitivity property for type-level naturals. We help GHC out 😀.
--
-- Implementation note: evidence for @k <= m@ cannot be built directly, so we take evidence for a
-- comparison GHC /can/ check (@1 <= 2@) and 'unsafeCoerce' it. This is only sound because the
-- @k <= l@ and @l <= m@ constraints in the signature guarantee that @k <= m@ really holds.
-- NOTE(review): it also relies on every @<=@ 'Dict' sharing one runtime representation;
-- re-check this when upgrading GHC.
proveTransitivityOfLessThanEqual :: (KnownNat k, KnownNat l, KnownNat m, k <= l, l <= m) => Dict (k <= m)
proveTransitivityOfLessThanEqual = unsafeCoerce $ Dict @(1 <= 2)

-- | Scheduling directives applied directly to a single 'Stage'.
--
-- Every method unwraps its arguments into raw C++ pointers ('withCxxStage', 'asVarOrRVar',
-- 'asExpr', ...). It then calls the matching member function of @Halide::Stage@ inside
-- @handle_halide_exceptions@. Directives returning @IO (Stage n a)@ hand back the very same
-- stage, except 'specialize', which returns the newly created specialized stage.
instance KnownNat n => Schedulable Stage n a where
  vectorize :: VarOrRVar -> Stage n a -> IO (Stage n a)
  vectorize var stage = do
    withCxxStage stage $ \stage' ->
      asVarOrRVar var $ \var' ->
        [C.throwBlock| void {
          handle_halide_exceptions([=](){
            $(Halide::Stage* stage')->vectorize(*$(const Halide::VarOrRVar* var'));
          });
        } |]
    pure stage

  unroll :: VarOrRVar -> Stage n a -> IO (Stage n a)
  unroll var stage = do
    withCxxStage stage $ \stage' ->
      asVarOrRVar var $ \var' ->
        [C.throwBlock| void {
          handle_halide_exceptions([=](){
            $(Halide::Stage* stage')->unroll(*$(const Halide::VarOrRVar* var'));
          });
        } |]
    pure stage

  reorder :: [VarOrRVar] -> Stage n a -> IO (Stage n a)
  reorder args stage = do
    withMany asVarOrRVar args $ \args' ->
      withCxxStage stage $ \stage' ->
        [C.throwBlock| void {
          handle_halide_exceptions([=]() {
            $(Halide::Stage* stage')->reorder(
              *$(const std::vector<Halide::VarOrRVar>* args'));
          });
        } |]
    pure stage

  split :: TailStrategy -> VarOrRVar -> (VarOrRVar, VarOrRVar) -> Expr Int32 -> Stage n a -> IO (Stage n a)
  split tail old (outer, inner) factor stage = do
    withCxxStage stage $ \stage' ->
      asVarOrRVar old $ \old' ->
        asVarOrRVar outer $ \outer' ->
          asVarOrRVar inner $ \inner' ->
            asExpr factor $ \factor' ->
              [C.throwBlock| void {
                handle_halide_exceptions([=](){
                  $(Halide::Stage* stage')->split(
                    *$(const Halide::VarOrRVar* old'),
                    *$(const Halide::VarOrRVar* outer'),
                    *$(const Halide::VarOrRVar* inner'),
                    *$(const Halide::Expr* factor'),
                    static_cast<Halide::TailStrategy>($(int t)));
                });
              } |]
    pure stage
    where
      -- The C++ side receives the strategy as its 'fromEnum' index.
      t :: C.CInt
      t = fromIntegral . fromEnum $ tail

  -- Halide's @Stage::fuse@ takes @(inner, outer, fused)@. The first tuple component is
  -- forwarded as Halide's @inner@ argument, so the local names below follow Halide's order.
  fuse :: (VarOrRVar, VarOrRVar) -> VarOrRVar -> Stage n a -> IO (Stage n a)
  fuse (inner, outer) fused stage = do
    withCxxStage stage $ \stage' ->
      asVarOrRVar inner $ \inner' ->
        asVarOrRVar outer $ \outer' ->
          asVarOrRVar fused $ \fused' ->
            [C.throwBlock| void {
              handle_halide_exceptions([=](){
                $(Halide::Stage* stage')->fuse(
                  *$(const Halide::VarOrRVar* inner'),
                  *$(const Halide::VarOrRVar* outer'),
                  *$(const Halide::VarOrRVar* fused'));
              });
            } |]
    pure stage

  serial :: VarOrRVar -> Stage n a -> IO (Stage n a)
  serial var stage = do
    withCxxStage stage $ \stage' ->
      asVarOrRVar var $ \var' ->
        [C.throwBlock| void {
          handle_halide_exceptions([=](){
            $(Halide::Stage* stage')->serial(*$(const Halide::VarOrRVar* var'));
          });
        } |]
    pure stage

  parallel :: VarOrRVar -> Stage n a -> IO (Stage n a)
  parallel var stage = do
    withCxxStage stage $ \stage' ->
      asVarOrRVar var $ \var' ->
        [C.throwBlock| void {
          handle_halide_exceptions([=](){
            $(Halide::Stage* stage')->parallel(*$(const Halide::VarOrRVar* var'));
          });
        } |]
    pure stage

  atomic :: Bool -> Stage n a -> IO (Stage n a)
  atomic (fromIntegral . fromEnum -> override) stage = do
    withCxxStage stage $ \stage' ->
      [C.throwBlock| void {
        handle_halide_exceptions([=](){
          $(Halide::Stage* stage')->atomic($(bool override));
        });
      } |]
    pure stage

  -- Allocates a fresh C++ @Halide::Stage@ for the specialization; 'wrapCxxStage' takes ownership.
  specialize :: Expr Bool -> Stage n a -> IO (Stage n a)
  specialize cond stage = do
    withCxxStage stage $ \stage' ->
      asExpr cond $ \cond' ->
        wrapCxxStage
          =<< [C.throwBlock| Halide::Stage* {
                return handle_halide_exceptions([=](){
                  return new Halide::Stage{$(Halide::Stage* stage')->specialize(
                    *$(const Halide::Expr* cond'))};
                });
              } |]

  -- The message is passed to C++ as UTF-8 bytes with an explicit length.
  specializeFail :: Text -> Stage n a -> IO ()
  specializeFail (T.encodeUtf8 -> s) stage =
    withCxxStage stage $ \stage' ->
      [C.throwBlock| void {
        return handle_halide_exceptions([=](){
          $(Halide::Stage* stage')->specialize_fail(
            std::string{$bs-ptr:s, static_cast<size_t>($bs-len:s)});
        });
      } |]

  gpuBlocks :: forall k. (KnownNat k, 1 <= k, k <= 3) => DeviceAPI -> IndexType k -> Stage n a -> IO (Stage n a)
  gpuBlocks (fromIntegral . fromEnum -> api :: C.CInt) vars stage =
    -- 'proveIndexTypeProperties' needs @k <= 10@; we only know @k <= 3@, hence the transitivity proof.
    case proveTransitivityOfLessThanEqual @k @3 @10 of
      Dict -> case proveIndexTypeProperties @k of
        Sub Dict -> do
          withCxxStage stage $ \stage' ->
            asVectorOf @((~) (Expr Int32)) asVarOrRVar (fromTuple vars) $ \vars' ->
              [C.throwBlock| void {
                handle_halide_exceptions([=](){
                  auto const& vars = *$(const std::vector<Halide::VarOrRVar>* vars');
                  auto& stage = *$(Halide::Stage* stage');
                  auto const device = static_cast<Halide::DeviceAPI>($(int api));
                  switch (vars.size()) {
                    case 1: stage.gpu_blocks(vars.at(0), device);
                            break;
                    case 2: stage.gpu_blocks(vars.at(0), vars.at(1), device);
                            break;
                    case 3: stage.gpu_blocks(vars.at(0), vars.at(1), vars.at(2), device);
                            break;
                    default: throw std::runtime_error{"unexpected number of arguments in gpuBlocks"};
                  }
                });
              } |]
          pure stage

  gpuThreads :: forall k. (KnownNat k, 1 <= k, k <= 3) => DeviceAPI -> IndexType k -> Stage n a -> IO (Stage n a)
  gpuThreads (fromIntegral . fromEnum -> api :: C.CInt) vars stage =
    -- Same proof dance as in 'gpuBlocks'.
    case proveTransitivityOfLessThanEqual @k @3 @10 of
      Dict -> case proveIndexTypeProperties @k of
        Sub Dict -> do
          withCxxStage stage $ \stage' ->
            asVectorOf @((~) (Expr Int32)) asVarOrRVar (fromTuple vars) $ \vars' ->
              [C.throwBlock| void {
                handle_halide_exceptions([=](){
                  auto const& vars = *$(const std::vector<Halide::VarOrRVar>* vars');
                  auto& stage = *$(Halide::Stage* stage');
                  auto const device = static_cast<Halide::DeviceAPI>($(int api));
                  switch (vars.size()) {
                    case 1: stage.gpu_threads(vars.at(0), device);
                            break;
                    case 2: stage.gpu_threads(vars.at(0), vars.at(1), device);
                            break;
                    case 3: stage.gpu_threads(vars.at(0), vars.at(1), vars.at(2), device);
                            break;
                    default: throw std::runtime_error{"unexpected number of arguments in gpuThreads"};
                  }
                });
              } |]
          pure stage

  gpuLanes :: DeviceAPI -> VarOrRVar -> Stage n a -> IO (Stage n a)
  gpuLanes (fromIntegral . fromEnum -> api) var stage = do
    withCxxStage stage $ \stage' ->
      asVarOrRVar var $ \var' ->
        [C.throwBlock| void {
          handle_halide_exceptions([=](){
            $(Halide::Stage* stage')->gpu_lanes(
              *$(const Halide::VarOrRVar* var'),
              static_cast<Halide::DeviceAPI>($(int api)));
          });
        } |]
    pure stage

  computeWith :: LoopAlignStrategy -> Stage n a -> LoopLevel t -> IO ()
  computeWith (fromIntegral . fromEnum -> align) stage level = do
    withCxxStage stage $ \stage' ->
      withCxxLoopLevel level $ \level' ->
        [C.throwBlock| void {
          handle_halide_exceptions([=]() {
            $(Halide::Stage* stage')->compute_with(
              *$(const Halide::LoopLevel* level'),
              static_cast<Halide::LoopAlignStrategy>($(int align)));
          });
        } |]

-- | Lift a one-argument 'Stage' scheduling directive to a t'Func'.
--
-- The directive is applied to the stage obtained from 'getStage'. Its result is discarded,
-- and the original function is returned so that calls can be chained.
viaStage1
  :: KnownNat n
  => (a -> Stage n b -> IO (Stage n b))
  -> a
  -> Func t n b
  -> IO (Func t n b)
viaStage1 f a1 func = func <$ (f a1 =<< getStage func)

-- | Lift a two-argument 'Stage' scheduling directive to a t'Func'.
--
-- The directive is applied to the stage obtained from 'getStage'. Its result is discarded,
-- and the original function is returned so that calls can be chained.
viaStage2
  :: KnownNat n
  => (a1 -> a2 -> Stage n b -> IO (Stage n b))
  -> a1
  -> a2
  -> Func t n b
  -> IO (Func t n b)
viaStage2 f a1 a2 func = func <$ (f a1 a2 =<< getStage func)

{-
viaStage3
  :: (KnownNat n, IsHalideType b)
  => (a1 -> a2 -> a3 -> Stage n b -> IO (Stage n b))
  -> a1
  -> a2
  -> a3
  -> Func t n b
  -> IO (Func t n b)
viaStage3 f a1 a2 a3 func = do
  _ <- f a1 a2 a3 =<< getStage func
  pure func
-}

-- | Lift a four-argument 'Stage' scheduling directive (e.g. 'split') to a t'Func'.
--
-- The directive is applied to the stage obtained from 'getStage'. Its result is discarded,
-- and the original function is returned so that calls can be chained.
viaStage4
  :: KnownNat n
  => (a1 -> a2 -> a3 -> a4 -> Stage n b -> IO (Stage n b))
  -> a1
  -> a2
  -> a3
  -> a4
  -> Func t n b
  -> IO (Func t n b)
viaStage4 f a1 a2 a3 a4 func = func <$ (f a1 a2 a3 a4 =<< getStage func)

-- | Scheduling a 'Func' delegates to the corresponding 'Stage' method, applied
-- to the stage obtained with 'getStage'.
--
-- Primitives that return the scheduled object go through the @viaStageN@
-- helpers, which return the original 'Func'. 'specialize', 'specializeFail'
-- and 'computeWith' return whatever the 'Stage' method returns.
instance KnownNat n => Schedulable (Func t) n a where
  -- one extra argument before the scheduled object
  vectorize = viaStage1 vectorize
  unroll = viaStage1 unroll
  reorder = viaStage1 reorder
  serial = viaStage1 serial
  parallel = viaStage1 parallel
  atomic = viaStage1 atomic

  -- two extra arguments
  fuse = viaStage2 fuse
  gpuBlocks = viaStage2 gpuBlocks
  gpuThreads = viaStage2 gpuThreads
  gpuLanes = viaStage2 gpuLanes

  -- four extra arguments
  split = viaStage4 split

  -- these return the 'Stage' method's result directly
  specialize cond func = specialize cond =<< getStage func
  specializeFail msg func = specializeFail msg =<< getStage func
  computeWith strategy func level = do
    stage <- getStage func
    computeWith strategy stage level

-- | Conversion between 'TailStrategy' and the integer values of the C++
-- @Halide::TailStrategy@ enumeration.
--
-- The integer codes come from the C++ enum itself (via inline-C
-- @static_cast<int>@) rather than being hard-coded on the Haskell side.
instance Enum TailStrategy where
  fromEnum strategy =
    fromIntegral $ case strategy of
      TailRoundUp -> [CU.pure| int { static_cast<int>(Halide::TailStrategy::RoundUp) } |]
      TailGuardWithIf -> [CU.pure| int { static_cast<int>(Halide::TailStrategy::GuardWithIf) } |]
      TailPredicate -> [CU.pure| int { static_cast<int>(Halide::TailStrategy::Predicate) } |]
      TailPredicateLoads -> [CU.pure| int { static_cast<int>(Halide::TailStrategy::PredicateLoads) } |]
      TailPredicateStores -> [CU.pure| int { static_cast<int>(Halide::TailStrategy::PredicateStores) } |]
      TailShiftInwards -> [CU.pure| int { static_cast<int>(Halide::TailStrategy::ShiftInwards) } |]
      TailAuto -> [CU.pure| int { static_cast<int>(Halide::TailStrategy::Auto) } |]

  -- Codes not produced by the C++ enum are rejected with 'error'.
  toEnum k =
    case lookup (fromIntegral k) codeTable of
      Just strategy -> strategy
      Nothing -> error $ "invalid TailStrategy: " <> show k
    where
      codeTable =
        [ ([CU.pure| int { static_cast<int>(Halide::TailStrategy::RoundUp) } |], TailRoundUp)
        , ([CU.pure| int { static_cast<int>(Halide::TailStrategy::GuardWithIf) } |], TailGuardWithIf)
        , ([CU.pure| int { static_cast<int>(Halide::TailStrategy::Predicate) } |], TailPredicate)
        , ([CU.pure| int { static_cast<int>(Halide::TailStrategy::PredicateLoads) } |], TailPredicateLoads)
        , ([CU.pure| int { static_cast<int>(Halide::TailStrategy::PredicateStores) } |], TailPredicateStores)
        , ([CU.pure| int { static_cast<int>(Halide::TailStrategy::ShiftInwards) } |], TailShiftInwards)
        , ([CU.pure| int { static_cast<int>(Halide::TailStrategy::Auto) } |], TailAuto)
        ]

-- | Statically declare the range over which the function will be evaluated in the general case.
--
-- This provides a basis for the auto scheduler to make trade-offs and scheduling decisions.
-- The auto generated schedules might break when the sizes of the dimensions are very different from the
-- estimates specified. These estimates are used only by the auto scheduler if the function is a pipeline output.
--
-- The call to @Halide::Func::set_estimate@ runs inside @handle_halide_exceptions@ in a
-- @throwBlock@. Errors reported by Halide therefore reach Haskell instead of unwinding
-- through an unchecked FFI call.
estimate
  :: KnownNat n
  => Expr Int32
  -- ^ index variable
  -> Expr Int32
  -- ^ @min@ estimate
  -> Expr Int32
  -- ^ @extent@ estimate
  -> Func t n a
  -> IO ()
estimate var start extent func =
  withFunc func $ \f ->
    asVar var $ \i ->
      asExpr start $ \minExpr ->
        asExpr extent $ \extentExpr ->
          [C.throwBlock| void {
            handle_halide_exceptions([=](){
              $(Halide::Func* f)->set_estimate(
                *$(Halide::Var* i), *$(Halide::Expr* minExpr), *$(Halide::Expr* extentExpr));
            });
          } |]

-- | Statically declare the range over which a function should be evaluated.
--
-- This can let Halide perform some optimizations. E.g. if you know there are going to be 4 color channels,
-- you can completely vectorize the color channel dimension without the overhead of splitting it up.
-- If bounds inference decides that it requires more of this function than the bounds you have stated,
-- a runtime error will occur when you try to run your pipeline.
--
-- The call to @Halide::Func::bound@ runs inside @handle_halide_exceptions@ in a
-- @throwBlock@. Errors reported by Halide therefore reach Haskell instead of unwinding
-- through an unchecked FFI call.
bound
  :: KnownNat n
  => Expr Int32
  -- ^ index variable
  -> Expr Int32
  -- ^ @min@ of the bound
  -> Expr Int32
  -- ^ @extent@ of the bound
  -> Func t n a
  -> IO ()
bound var start extent func =
  withFunc func $ \f ->
    asVar var $ \i ->
      asExpr start $ \minExpr ->
        asExpr extent $ \extentExpr ->
          [C.throwBlock| void {
            handle_halide_exceptions([=](){
              $(Halide::Func* f)->bound(
                *$(Halide::Var* i), *$(Halide::Expr* minExpr), *$(Halide::Expr* extentExpr));
            });
          } |]

-- | Get the index arguments of the function.
--
-- The returned list contains exactly @n@ elements, and is empty for a
-- zero-dimensional function.
getArgs :: KnownNat n => Func t n a -> IO [Var]
getArgs func =
  withFunc func $ \fPtr -> do
    -- Temporary heap copy of @Func::args()@, released by 'bracket' even on error.
    let allocate =
          [CU.exp| std::vector<Halide::Var>* {
            new std::vector<Halide::Var>{$(const Halide::Func* fPtr)->args()} } |]
        destroy v = [CU.exp| void { delete $(std::vector<Halide::Var>* v) } |]
    bracket allocate destroy $ \v -> do
      n <- [CU.exp| size_t { $(const std::vector<Halide::Var>* v)->size() } |]
      -- @n@ is an unsigned size_t, so @[0 .. n - 1]@ would wrap around to
      -- @maxBound@ when @n == 0@. The first @at(0)@ would then throw. Counting
      -- up while @i < n@ yields the empty list in that case.
      forM (takeWhile (< n) [0 ..]) $ \i ->
        fmap Var . cxxConstruct $ \ptr ->
          [CU.exp| void {
            new ($(Halide::Var* ptr)) Halide::Var{$(const std::vector<Halide::Var>* v)->at($(size_t i))} } |]

-- | Compute all of this function once ahead of time.
--
-- Schedules the underlying @Halide::Func@ with @compute_root()@ and returns the
-- same 'Func' handle for chaining.
--
-- See [Halide::Func::compute_root](https://halide-lang.org/docs/class_halide_1_1_func.html#a29df45a4a16a63eb81407261a9783060) for more info.
computeRoot :: KnownNat n => Func t n a -> IO (Func t n a)
computeRoot func = func <$ withFunc func scheduleAtRoot
  where
    scheduleAtRoot f =
      [C.throwBlock| void { handle_halide_exceptions([=](){ $(Halide::Func* f)->compute_root(); }); } |]

-- | Creates and returns a new identity Func that wraps this Func.
--
-- Wraps @g@ for use in @f@ via @g.in(f)@. During compilation, Halide replaces all
-- calls to @g@ made by @f@ with calls to the wrapper. If @g@ is already wrapped for
-- use in @f@, the existing wrapper is returned.
--
-- For more info, see [Halide::Func::in](https://halide-lang.org/docs/class_halide_1_1_func.html#a9d619f2d0111ea5bf640781d1324d050).
asUsedBy
  :: (KnownNat n, KnownNat m)
  => Func t1 n a
  -> Func 'FuncTy m b
  -> IO (Func 'FuncTy n a)
asUsedBy g f =
  withFunc g $ \gPtr ->
    withFunc f $ \fPtr -> do
      -- heap copy of the wrapper; ownership passes to the returned 'Func'
      wrapper <-
        [CU.exp| Halide::Func* {
          new Halide::Func{$(Halide::Func* gPtr)->in(*$(Halide::Func* fPtr))} } |]
      wrapCxxFunc wrapper

-- | Create and return a global identity wrapper, which wraps all calls to this Func by any other Func.
--
-- If a global wrapper already exists, returns it. The global identity wrapper is only used by callers
-- for which no custom wrapper has been specified (see 'asUsedBy').
asUsed :: KnownNat n => Func t n a -> IO (Func 'FuncTy n a)
asUsed f =
  withFunc f $ \fPtr -> do
    -- heap copy of @f.in()@; ownership passes to the returned 'Func'
    wrapper <- [CU.exp| Halide::Func* { new Halide::Func{$(Halide::Func* fPtr)->in()} } |]
    wrapCxxFunc wrapper

-- | Declare that this function should be implemented by a call to @halide_buffer_copy@ with the given
-- target device API.
--
-- Asserts that the @Func@ has a pure definition which is a simple call to a single input, and no update
-- definitions. The wrapper @Func@s returned by 'asUsed' are suitable candidates. Consumes all pure variables,
-- and rewrites the @Func@ to have an extern definition that calls @halide_buffer_copy@.
--
-- The 'DeviceAPI' is passed to C++ as its 'fromEnum' value cast to @Halide::DeviceAPI@.
-- The same 'Func' handle is returned for chaining.
copyToDevice :: KnownNat n => DeviceAPI -> Func t n a -> IO (Func t n a)
copyToDevice deviceApi func = do
  let api = fromIntegral (fromEnum deviceApi)
  withFunc func $ \f ->
    [C.throwBlock| void {
      handle_halide_exceptions([=](){
        $(Halide::Func* f)->copy_to_device(static_cast<Halide::DeviceAPI>($(int api)));
      });
    } |]
  pure func

-- | Same as @'copyToDevice' 'DeviceHost'@: schedule the function as a copy back
-- to host memory.
copyToHost :: KnownNat n => Func t n a -> IO (Func t n a)
copyToHost func = copyToDevice DeviceHost func

-- | Allocate a fresh @Halide::ImageParam@ with element type @a@ and @n@ dimensions.
--
-- When a name is given it is passed to the C++ constructor (UTF-8 encoded);
-- otherwise the unnamed constructor is used. The result is owned by a
-- 'ForeignPtr' whose finalizer deletes the C++ object.
mkBufferParameter
  :: forall n a. (KnownNat n, IsHalideType a) => Maybe Text -> IO (ForeignPtr CxxImageParam)
mkBufferParameter maybeName =
  with (halideTypeFor (Proxy @a)) $ \t -> do
    rawParam <- case maybeName of
      Nothing ->
        [CU.exp| Halide::ImageParam* {
          new Halide::ImageParam{Halide::Type{*$(halide_type_t* t)}, $(int d)} } |]
      Just name -> do
        let s = T.encodeUtf8 name
        [CU.exp| Halide::ImageParam* {
          new Halide::ImageParam{
                Halide::Type{*$(halide_type_t* t)},
                $(int d),
                std::string{$bs-ptr:s, static_cast<size_t>($bs-len:s)}} } |]
    newForeignPtr deleter rawParam
  where
    -- number of dimensions, taken from the type-level @n@
    d = fromIntegral (natVal (Proxy @n))
    deleter = [C.funPtr| void deleteImageParam(Halide::ImageParam* p) { delete p; } |]

-- | Return the @Halide::ImageParam@ cached in the given 'IORef', creating it
-- with 'mkBufferParameter' and storing it on first use.
--
-- NOTE(review): the read and the write are separate 'IORef' operations. Two
-- concurrent first uses could each allocate a parameter.
getBufferParameter
  :: forall n a
   . (KnownNat n, IsHalideType a)
  => Maybe Text
  -> IORef (Maybe (ForeignPtr CxxImageParam))
  -> IO (ForeignPtr CxxImageParam)
getBufferParameter name r = readIORef r >>= maybe create pure
  where
    create = do
      fresh <- mkBufferParameter @n @a name
      writeIORef r (Just fresh)
      pure fresh

-- | Like 'withFunc', but restricted to a 'Param' and yielding the raw
-- @Halide::ImageParam@ pointer instead of a @Halide::Func@ one.
--
-- The parameter is created lazily (unnamed) via 'getBufferParameter' if it
-- does not exist yet.
withBufferParam
  :: forall n a b
   . (HasCallStack, KnownNat n, IsHalideType a)
  => Func 'ParamTy n (Expr a)
  -> (Ptr CxxImageParam -> IO b)
  -> IO b
withBufferParam (Param r) action = do
  fp <- getBufferParameter @n @a Nothing r
  withForeignPtr fp action

-- | Get the underlying pointer to @Halide::Func@ and invoke an 'IO' action with it.
--
-- A 'Param' is first converted with 'forceFunc'. That wraps its
-- @Halide::ImageParam@ in a newly allocated @Halide::Func@ on every call.
withFunc :: KnownNat n => Func t n a -> (Ptr CxxFunc -> IO b) -> IO b
withFunc f action = case f of
  Func fp -> withForeignPtr fp action
  p@(Param _) -> do
    materialised <- forceFunc p
    withCxxFunc materialised action

-- | Run an action with the raw @Halide::Func@ pointer of a 'Func' that is
-- already backed by a @Halide::Func@ (i.e. not a 'Param').
withCxxFunc :: KnownNat n => Func 'FuncTy n a -> (Ptr CxxFunc -> IO b) -> IO b
withCxxFunc func action = case func of
  Func fp -> withForeignPtr fp action

-- | Take ownership of a heap-allocated @Halide::Func@.
--
-- The pointer is wrapped in a 'ForeignPtr' whose finalizer @delete@s it, so
-- callers must pass a pointer obtained from C++ @new@.
wrapCxxFunc :: Ptr CxxFunc -> IO (Func 'FuncTy n a)
wrapCxxFunc ptr = Func <$> newForeignPtr finalizer ptr
  where
    finalizer = [C.funPtr| void deleteFunc(Halide::Func *x) { delete x; } |]

-- | Turn any 'Func' into one backed by a @Halide::Func@.
--
-- A 'Func' is returned as is. For a 'Param', the (lazily created)
-- @Halide::ImageParam@ is converted to a newly allocated @Halide::Func@ that
-- the result owns.
forceFunc :: forall t n a. KnownNat n => Func t n (Expr a) -> IO (Func 'FuncTy n (Expr a))
forceFunc func = case func of
  Func _ -> pure func
  Param r -> do
    fp <- getBufferParameter @n @a Nothing r
    withForeignPtr fp $ \p -> do
      converted <-
        [CU.exp| Halide::Func* {
          new Halide::Func{static_cast<Halide::Func>(*$(Halide::ImageParam* p))} } |]
      wrapCxxFunc converted

-- | Values that can form the right-hand side of a function definition,
-- e.g. a single 'Expr' or a tuple of 'Expr's.
class IsFuncDefinition d where
  -- | Rebuild a definition from its list of @Halide::Expr@ handles.
  exprListToDefinition :: [ForeignPtr CxxExpr] -> d

  -- | Flatten a definition into its list of @Halide::Expr@ handles.
  definitionToExprList :: d -> [ForeignPtr CxxExpr]

-- | A single expression is a one-element definition.
instance IsHalideType a => IsFuncDefinition (Expr a) where
  definitionToExprList e = [exprToForeignPtr e]

  -- 'unsafePerformIO' only runs 'checkType' on the expression, which presumably
  -- fails when its Halide type does not match @a@. TODO confirm it has no other
  -- effects.
  exprListToDefinition = \case
    [x1] -> unsafePerformIO $ do
      withForeignPtr x1 (checkType @a)
      pure (Expr x1)
    _ -> error "should never happen"

-- | A pair of expressions is a two-element (Halide @Tuple@) definition.
instance (IsHalideType a1, IsHalideType a2) => IsFuncDefinition (Expr a1, Expr a2) where
  definitionToExprList (e1, e2) = [exprToForeignPtr e1, exprToForeignPtr e2]

  -- Each element is checked against its expected type before rewrapping.
  exprListToDefinition = \case
    [p1, p2] -> unsafePerformIO $ do
      withForeignPtr p1 (checkType @a1)
      withForeignPtr p2 (checkType @a2)
      pure (Expr p1, Expr p2)
    _ -> error "should never happen"

-- | A triple of expressions is a three-element (Halide @Tuple@) definition.
instance (IsHalideType a1, IsHalideType a2, IsHalideType a3) => IsFuncDefinition (Expr a1, Expr a2, Expr a3) where
  definitionToExprList (e1, e2, e3) =
    [exprToForeignPtr e1, exprToForeignPtr e2, exprToForeignPtr e3]

  -- Each element is checked against its expected type before rewrapping.
  exprListToDefinition = \case
    [p1, p2, p3] -> unsafePerformIO $ do
      withForeignPtr p1 (checkType @a1)
      withForeignPtr p2 (checkType @a2)
      withForeignPtr p3 (checkType @a3)
      pure (Expr p1, Expr p2, Expr p3)
    _ -> error "should never happen"

-- | Define a Halide function.
--
-- @define "f" i e@ defines a Halide function called "f" such that @f[i] = e@.
--
-- Here, @i@ is an @n@-element tuple of t'Var', i.e. the following are all valid:
--
-- >>> [x, y, z] <- mapM mkVar ["x", "y", "z"]
-- >>> f1 <- define "f1" x (0 :: Expr Float)
-- >>> f2 <- define "f2" (x, y) (0 :: Expr Float)
-- >>> f3 <- define "f3" (x, y, z) (0 :: Expr Float)
--
-- A single-expression definition is assigned directly; a multi-expression
-- definition is assigned as a @Halide::Tuple@. As in 'update', the assignment
-- runs inside @handle_halide_exceptions@ in a @throwBlock@. Errors reported by
-- Halide while building the definition therefore reach Haskell instead of
-- unwinding through an unchecked FFI call.
define
  :: forall n d
   . (HasIndexType n, IsFuncDefinition d)
  => Text
  -> IndexType n
  -> d
  -> IO (Func 'FuncTy n d)
define name args definition =
  case proveIndexTypeProperties @n of
    Sub Dict ->
      asVectorOf @((~) (Expr Int32)) asVar (fromTuple args) $ \x -> do
        let s = T.encodeUtf8 name
        withMany withForeignPtr (definitionToExprList definition) $ \v ->
          wrapCxxFunc
            =<< [C.throwBlock| Halide::Func* {
                  Halide::Func* result = nullptr;
                  handle_halide_exceptions([=, &result](){
                    Halide::Func f{std::string{$bs-ptr:s, static_cast<size_t>($bs-len:s)}};
                    auto const& args = *$(const std::vector<Halide::Var>* x);
                    auto const& def = *$(const std::vector<Halide::Expr>* v);
                    if (def.size() == 1) {
                      f(args) = def.at(0);
                    }
                    else {
                      f(args) = Halide::Tuple{def};
                    }
                    result = new Halide::Func{f};
                  });
                  return result;
                } |]

-- | Create an update definition for a Halide function.
--
-- @update f i e@ creates an update definition for @f@ that performs @f[i] = e@.
-- A single-expression value is assigned directly; several expressions are
-- assigned as a @Halide::Tuple@. The assignment runs inside
-- @handle_halide_exceptions@ in a @throwBlock@.
update
  :: forall n d
   . (HasIndexType n, IsFuncDefinition d)
  => Func 'FuncTy n d
  -> IndexType n
  -> d
  -> IO ()
update func args definition =
  case proveIndexTypeProperties @n of
    Sub Dict ->
      withFunc func $ \fPtr ->
        asVectorOf @((~) (Expr Int32)) asExpr (fromTuple args) $ \indexVec ->
          withMany withForeignPtr (definitionToExprList definition) $ \valueVec ->
            [C.throwBlock| void {
              handle_halide_exceptions([=](){
                auto& f = *$(Halide::Func* fPtr);
                auto const& index = *$(const std::vector<Halide::Expr>* indexVec);
                auto const& value = *$(const std::vector<Halide::Expr>* valueVec);
                if (value.size() == 1) {
                  f(index) = value.at(0);
                }
                else {
                  f(index) = Halide::Tuple{value};
                }
              });
            } |]

-- Fixity of (!): non-associative, precedence 9. That is tighter than the
-- Prelude arithmetic operators (e.g. + is infixl 6), so
-- @f ! i + g ! i@ parses as @(f ! i) + (g ! i)@.
infix 9 !

-- | Marshal a tuple of index expressions into a temporary C++ vector of
-- @Halide::Expr@ and run the continuation with a pointer to it.
--
-- The vector is produced by 'asVectorOf'; treat the pointer as valid only
-- inside the continuation.
withExprIndices
  :: forall n a
   . HasIndexType n
  => IndexType n
  -> (Ptr (CxxVector CxxExpr) -> IO a)
  -> IO a
withExprIndices indices action =
  case proveIndexTypeProperties @n of
    -- Matching on the proof brings 'IndexTypeProperties n' into scope, which
    -- supplies the constraints needed by 'fromTuple' and 'asVectorOf' below.
    Sub Dict ->
      let components = fromTuple indices
       in asVectorOf @((~) (Expr Int32)) asExpr components action

-- | Apply a function to a tuple of indices and collect the resulting
-- expressions.
--
-- If the resulting @Halide::FuncRef@ has a single output, the returned list
-- has one element. Otherwise it has one element per output (@ref[i]@), as for
-- tuple-valued functions.
indexFunc :: forall n a t. HasIndexType n => Func t n a -> IndexType n -> IO [ForeignPtr CxxExpr]
indexFunc func indices = withExprIndices indices $ \x ->
  withFunc func $ \f -> do
    -- Build f(x) on the C++ side and copy its outputs into a freshly
    -- heap-allocated vector; 'bracket' below guarantees it is deleted.
    let allocate =
          [CU.block| std::vector<Halide::Expr>* {
            Halide::FuncRef ref = $(Halide::Func* f)->operator()(*$(std::vector<Halide::Expr>* x));
            std::vector<Halide::Expr> v;
            if (ref.size() == 1) {
              v.push_back(static_cast<Halide::Expr>(ref));
            }
            else {
              for (auto i = size_t{0}; i < ref.size(); ++i) {
                v.push_back(ref[i]);
              }
            }
            return new std::vector<Halide::Expr>{std::move(v)};
          } |]
    bracket allocate deleteCxxVector $ \v -> do
      size <- cxxVectorSize v
      -- Enumerate with 'Int' and convert each index to size_t separately.
      -- With an unsigned type (CSize) an empty vector would make
      -- [0 .. size - 1] wrap around to [0 .. maxBound], and the first
      -- out-of-range at() would throw through an unsafe FFI call.
      forM [0 .. size - 1] $ \k -> do
        let i = fromIntegral k
        cxxConstruct $ \ptr ->
          [CU.exp| void { new ($(Halide::Expr* ptr)) Halide::Expr{
                $(const std::vector<Halide::Expr>* v)->at($(size_t i))} } |]

-- | Apply a Halide function. Conceptually, @f ! i@ is equivalent to @f[i]@,
-- i.e. indexing into a lazy array.
--
-- The result is the expression (or expressions, for tuple-valued functions)
-- that 'indexFunc' produces for @f@ at the given indices, converted back with
-- 'exprListToDefinition'.
--
-- NOTE(review): this exposes 'indexFunc' as a pure function via
-- 'unsafePerformIO'; it relies on building the reference having no
-- observable side effects.
(!) :: (HasIndexType n, IsFuncDefinition a) => Func t n a -> IndexType n -> a
func ! args = unsafePerformIO (exprListToDefinition <$> indexFunc func args)

-- | Get a particular dimension of a pipeline parameter.
--
-- Dimensions are numbered from @0@ to @n - 1@; any other index is reported
-- via 'error'.
dim
  :: forall n a
   . (HasCallStack, KnownNat n)
  => Int
  -> Func 'ParamTy n (Expr a)
  -> IO Dimension
dim k func@(Param _)
  | k < 0 || k >= rank = error outOfRange
  | otherwise =
      -- The 'Param' match above is what provides 'IsHalideType a', which
      -- 'withBufferParam' requires.
      withBufferParam func $ \f -> do
        let n = fromIntegral k
        raw <-
          [CU.exp| Halide::Internal::Dimension* {
            new Halide::Internal::Dimension{$(Halide::ImageParam* f)->dim($(int n))} } |]
        wrapCxxDimension raw
  where
    rank :: Int
    rank = fromIntegral (natVal (Proxy @n))

    outOfRange :: String
    outOfRange =
      "invalid dimension index: "
        <> show k
        <> "; Func is "
        <> show (natVal (Proxy @n))
        <> "-dimensional"

-- | Render the loop nests that the schedule specifies for this function.
--
-- Useful for understanding what a schedule is doing.
--
-- For more info, see
-- [@Halide::Func::print_loop_nest@](https://halide-lang.org/docs/class_halide_1_1_func.html#a03f839d9e13cae4b87a540aa618589ae)
prettyLoopNest :: KnownNat n => Func t n r -> IO Text
prettyLoopNest func = withFunc func $ \f -> do
  cxxStr <-
    [C.throwBlock| std::string* {
      return handle_halide_exceptions([=]() {
        return new std::string{Halide::Internal::print_loop_nest(
          std::vector<Halide::Internal::Function>{$(Halide::Func* f)->function()})};
      });
    } |]
  peekAndDeleteCxxString cxxStr

-- | Like 'realizeOnTarget', but always runs the pipeline on 'hostTarget'.
realize
  :: forall n a t b
   . (KnownNat n, IsHalideType a)
  => Func t n (Expr a)
  -- ^ Function to evaluate
  -> [Int]
  -- ^ Domain over which to evaluate
  -> (Ptr (HalideBuffer n a) -> IO b)
  -- ^ What to do with the buffer afterwards. The buffer is only allocated
  -- temporarily, so do not return it directly.
  -> IO b
realize func shape action = realizeOnTarget hostTarget func shape action

-- | Evaluate this function over a rectangular domain.
--
-- The output buffer is allocated for @target@ via 'allocaBuffer'. If your
-- target is a GPU, this function will not automatically copy data back from
-- the GPU.
realizeOnTarget
  :: forall n a t b
   . (KnownNat n, IsHalideType a)
  => Target
  -- ^ Target on which to run the pipeline
  -> Func t n (Expr a)
  -- ^ Function to evaluate
  -> [Int]
  -- ^ Domain over which to evaluate
  -> (Ptr (HalideBuffer n a) -> IO b)
  -- ^ What to do with the buffer afterwards. The buffer is only allocated
  -- temporarily, so do not return it directly.
  -> IO b
realizeOnTarget target func shape action =
  withFunc func $ \func' ->
    withCxxTarget target $ \target' ->
      allocaBuffer target shape $ \buf -> do
        runPipeline func' target' (castPtr buf)
        action buf
  where
    -- Realize the pipeline into an already-allocated raw halide_buffer_t.
    runPipeline func' target' raw =
      [C.throwBlock| void {
        handle_halide_exceptions([=](){
          $(Halide::Func* func')->realize($(halide_buffer_t* raw), *$(const Halide::Target* target'));
        });
      } |]

-- | A view pattern that assigns a name to a buffer argument.
--
-- Example usage:
--
-- >>> :{
-- _ <- compile $ \(buffer "src" -> src) -> do
--   i <- mkVar "i"
--   define "dest" i $ (src ! i :: Expr Float)
-- :}
--
-- or if we want to specify the dimension and type, we can use type applications:
--
-- >>> :{
-- _ <- compile $ \(buffer @1 @Float "src" -> src) -> do
--   i <- mkVar "i"
--   define "dest" i $ src ! i
-- :}
buffer :: forall n a. (KnownNat n, IsHalideType a) => Text -> Func 'ParamTy n (Expr a) -> Func 'ParamTy n (Expr a)
buffer name p@(Param r) =
  -- 'getBufferParameter' is run only for its effect on @r@ (presumably
  -- creating or naming the underlying ImageParam — see its definition);
  -- the 'Func' handed back is @p@ itself.
  unsafePerformIO (p <$ getBufferParameter @n @a (Just name) r)

-- | Similar to 'buffer', but for scalar parameters.
--
-- Example usage:
--
-- >>> :{
-- _ <- compile $ \(scalar @Float "a" -> a) -> do
--   i <- mkVar "i"
--   define "dest" i $ a
-- :}
--
-- Naming a parameter twice, or applying 'scalar' to an 'Expr' that is not a
-- scalar parameter, is an error.
scalar :: forall a. IsHalideType a => Text -> Expr a -> Expr a
scalar name = \case
  ScalarParam r -> unsafePerformIO $ do
    current <- readIORef r
    case current of
      Just _ -> error "the name of this Expr has already been set"
      Nothing -> mkScalarParameter @a (Just name) >>= writeIORef r . Just
    pure (ScalarParam r)
  _ -> error "cannot set the name of an expression that is not a parameter"

-- | Take ownership of a heap-allocated @Halide::Stage@. The returned 'Stage'
-- deletes it with a C++ finalizer when its 'ForeignPtr' is finalized.
wrapCxxStage :: Ptr CxxStage -> IO (Stage n a)
wrapCxxStage ptr = Stage <$> newForeignPtr deleteStage ptr
  where
    deleteStage = [C.funPtr| void deleteStage(Halide::Stage* p) { delete p; } |]

-- | Run an action with the raw @Halide::Stage@ pointer behind a 'Stage'.
--
-- 'withForeignPtr' keeps the stage alive while the action runs; do not let
-- the pointer escape.
withCxxStage :: Stage n a -> (Ptr CxxStage -> IO b) -> IO b
withCxxStage (Stage fp) action = withForeignPtr fp action

-- | Get the pure stage of a 'Func' for the purposes of scheduling it.
getStage :: KnownNat n => Func t n a -> IO (Stage n a)
getStage func =
  withFunc func $ \func' -> do
    -- Copy the pure stage into a heap-allocated Halide::Stage that the
    -- returned 'Stage' takes ownership of.
    raw <- [CU.exp| Halide::Stage* { new Halide::Stage{static_cast<Halide::Stage>(*$(Halide::Func* func'))} } |]
    wrapCxxStage raw

-- | Check whether the function has any update definitions.
--
-- Returns 'True' when it does, 'False' otherwise.
hasUpdateDefinitions :: KnownNat n => Func t n a -> IO Bool
hasUpdateDefinitions func =
  withFunc func $ \func' -> do
    flag <- [CU.exp| bool { $(const Halide::Func* func')->has_update_definition() } |]
    pure (toBool flag)

-- | Get a handle to an update step for the purposes of scheduling it.
--
-- The index is forwarded to @Halide::Func::update@, where @0@ selects the
-- first update definition. A negative index is rejected with 'error'. An index
-- past the last update definition is reported by Halide and surfaces as a
-- Haskell exception.
getUpdateStage :: KnownNat n => Int -> Func 'FuncTy n a -> IO (Stage n a)
getUpdateStage k func
  -- NOTE(review): Halide's Func::update only asserts idx < num_update_definitions(),
  -- so a negative index would reach its internal vector unchecked; reject it here.
  | k < 0 = error $ "getUpdateStage: invalid update index " <> show k
  | otherwise =
      withFunc func $ \func' -> do
        let k' = fromIntegral k
        -- Run under C.throwBlock + handle_halide_exceptions (as getLoopLevelAtStage
        -- does) so Halide's out-of-range error becomes a Haskell exception instead
        -- of propagating through an unsafe FFI call.
        raw <-
          [C.throwBlock| Halide::Stage* {
            return handle_halide_exceptions([=](){
              return new Halide::Stage{$(Halide::Func* func')->update($(int k'))};
            });
          } |]
        wrapCxxStage raw

-- | Identify the loop nest corresponding to some dimension of some function.
--
-- The loop level is built from the function, the loop variable and a stage
-- index. The result must be a locked loop level; anything else is reported
-- with 'error'.
getLoopLevelAtStage
  :: KnownNat n
  => Func t n a
  -> Expr Int32
  -> Int
  -- ^ update index
  -> IO (LoopLevel 'LockedTy)
getLoopLevelAtStage func var stageIndex =
  withFunc func $ \f ->
    asVarOrRVar var $ \i -> do
      raw <-
        [C.throwBlock| Halide::LoopLevel* {
          return handle_halide_exceptions([=](){
            return new Halide::LoopLevel{*$(const Halide::Func* f),
                                         *$(const Halide::VarOrRVar* i),
                                         $(int k)};
          });
        } |]
      SomeLoopLevel level <- wrapCxxLoopLevel raw
      -- Matching the 'LoopLevel' constructor refines the existential index to 'LockedTy.
      case level of
        LoopLevel _ -> pure level
        _ -> error $ "getLoopLevelAtStage: got " <> show level <> ", but expected a LoopLevel 'LockedTy"
  where
    k = fromIntegral stageIndex

-- | Like 'getLoopLevelAtStage', with the stage index fixed to @-1@.
getLoopLevel :: KnownNat n => Func t n a -> Expr Int32 -> IO (LoopLevel 'LockedTy)
getLoopLevel func var = getLoopLevelAtStage func var defaultStage
  where
    -- Passed through unchanged as the stage index of the Halide::LoopLevel.
    defaultStage :: Int
    defaultStage = -1

-- | Allocate storage for this function within a particular loop level.
--
-- Scheduling storage is optional, and can be used to separate the loop level at which storage is allocated
-- from the loop level at which computation occurs to trade off between locality and redundant work.
--
-- The schedule change is applied to the underlying @Halide::Func@, and the same
-- 'Func' is returned.
--
-- For more info, see [Halide::Func::store_at](https://halide-lang.org/docs/class_halide_1_1_func.html#a417c08f8aa3a5cdf9146fba948b65193).
storeAt :: KnownNat n => Func 'FuncTy n a -> LoopLevel t -> IO (Func 'FuncTy n a)
storeAt func level = func <$ schedule
  where
    schedule =
      withFunc func $ \f ->
        withCxxLoopLevel level $ \l ->
          [CU.exp| void { $(Halide::Func* f)->store_at(*$(const Halide::LoopLevel* l)) } |]

-- | Schedule a function to be computed within the iteration over a given loop level.
--
-- The schedule change is applied to the underlying @Halide::Func@, and the same
-- 'Func' is returned.
--
-- For more info, see [Halide::Func::compute_at](https://halide-lang.org/docs/class_halide_1_1_func.html#a800cbcc3ca5e3d3fa1707f6e1990ec83).
computeAt :: KnownNat n => Func 'FuncTy n a -> LoopLevel t -> IO (Func 'FuncTy n a)
computeAt func level = func <$ schedule
  where
    schedule =
      withFunc func $ \f ->
        withCxxLoopLevel level $ \l ->
          [CU.exp| void { $(Halide::Func* f)->compute_at(*$(const Halide::LoopLevel* l)) } |]

-- | Wrap a buffer into a t'Func'.
--
-- Suppose we are defining a pipeline that adds together two vectors, and we'd
-- like to call 'realize' to evaluate it directly. How do we pass the vectors to
-- the t'Func'? 'asBufferParam' does exactly this:
--
-- > asBufferParam ([1, 2, 3] :: [Float]) $ \a ->
-- >   asBufferParam ([4, 5, 6] :: [Float]) $ \b -> do
-- >     i <- mkVar "i"
-- >     f <- define "vectorAdd" i $ a ! i + b ! i
-- >     realize f [3] $ \result ->
-- >       print =<< peekToList result
--
-- A fresh, unnamed buffer parameter is created and bound to the buffer, and
-- the t'Func' passed to the continuation refers to that parameter. The buffer
-- comes from 'withHalideBuffer', so it is only guaranteed to be valid while
-- the continuation runs.
asBufferParam
  :: forall n a t b
   . IsHalideBuffer t n a
  => t
  -- ^ Object to treat as a buffer
  -> (Func 'ParamTy n (Expr a) -> IO b)
  -- ^ What to do with the __temporary__ buffer
  -> IO b
asBufferParam arr action =
  withHalideBuffer @n @a arr $ \arr' -> do
    param <- mkBufferParameter @n @a Nothing
    -- Bind the raw halide_buffer_t to the newly created ImageParam.
    withForeignPtr param $ \param' ->
      let buf = castPtr arr' :: Ptr RawHalideBuffer
       in [CU.block| void {
            $(Halide::ImageParam* param')->set(Halide::Buffer<>{*$(const halide_buffer_t* buf)});
          } |]
    action . Param =<< newIORef (Just param)