{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE QuasiQuotes #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE UndecidableInstances #-}

-- |
-- Module      : Language.Halide.Buffer
-- Description : Buffers
-- Copyright   : (c) Tom Westerhout, 2021-2023
--
-- A buffer in Halide is a __view__ of some multidimensional array. Buffers can reference data that's
-- located on a CPU, GPU, or another device. Halide pipelines use buffers for both input and output arguments.
module Language.Halide.Buffer
  ( -- * Buffers

  --

    -- | In the C interface of Halide, buffers are described by the C struct
    -- [@halide_buffer_t@](https://halide-lang.org/docs/structhalide__buffer__t.html). On the Haskell side,
    -- we have 'HalideBuffer'.
    HalideBuffer (..)
    -- | To easily test out your pipeline, there are helper functions to create 'HalideBuffer's without
    -- worrying about the low-level representation.
  , allocaCpuBuffer
  , allocaBuffer
    -- | Buffers can also be converted to lists to easily print them for debugging.
  , IsListPeek (..)
  , peekScalar
    -- | For production usage however, you don't want to work with lists. Instead, you probably want Halide
    -- to work with your existing array data types. For this, we define 'IsHalideBuffer' typeclass that
    -- teaches Halide how to convert your data into a 'HalideBuffer'. Depending on how you implement the
    -- instance, this can be very efficient, because it need not involve any memory copying.
  , IsHalideBuffer (..)
  , withHalideBuffer
    -- | There are also helper functions to simplify writing instances of 'IsHalideBuffer'.
  , bufferFromPtrShapeStrides
  , bufferFromPtrShape

    -- * Internals
  , RawHalideBuffer (..)
  , HalideDimension (..)
  , HalideDeviceInterface
  , rowMajorStrides
  , colMajorStrides
  , isDeviceDirty
  , isHostDirty
  , getBufferExtent
  , bufferCopyToHost
  , withCopiedToHost
  , withCropped
  )
where

import Control.Exception (bracket_)
import Control.Monad (forM, unless, when)
import Control.Monad.ST (RealWorld)
import Data.Int
import Data.Kind (Type)
import Data.List qualified as List
import Data.Proxy
import Data.Vector.Storable qualified as S
import Data.Vector.Storable.Mutable qualified as SM
import Data.Word
import Foreign.Marshal.Alloc (alloca, free, mallocBytes)
import Foreign.Marshal.Array
import Foreign.Marshal.Utils
import Foreign.Ptr
import Foreign.Storable
import GHC.Stack (HasCallStack)
import GHC.TypeNats
import Language.C.Inline qualified as C
import Language.C.Inline.Cpp.Exception qualified as C
import Language.C.Inline.Unsafe qualified as CU
import Language.Halide.Context
import Language.Halide.Target
import Language.Halide.Type
import Prelude hiding (min)

-- | Description of a single dimension of a buffer.
--
-- It is the Haskell analogue of [@halide_dimension_t@](https://halide-lang.org/docs/structhalide__dimension__t.html).
-- All four fields are strict and unpacked.
data HalideDimension = HalideDimension
  { halideDimensionMin :: {-# UNPACK #-} !Int32
  -- ^ Index of the first element along this dimension.
  , halideDimensionExtent :: {-# UNPACK #-} !Int32
  -- ^ Number of elements along this dimension.
  , halideDimensionStride :: {-# UNPACK #-} !Int32
  -- ^ Distance (in elements) between neighbouring elements along this dimension.
  , halideDimensionFlags :: {-# UNPACK #-} !Word32
  -- ^ Extra per-dimension flags.
  }
  deriving stock (Eq, Show, Read)

-- | Marshals a 'HalideDimension' as four consecutive 32-bit fields at byte offsets
-- 0 (min), 4 (extent), 8 (stride) and 12 (flags): 16 bytes, 4-byte aligned.
instance Storable HalideDimension where
  sizeOf _ = 16
  {-# INLINE sizeOf #-}
  alignment _ = 4
  {-# INLINE alignment #-}
  peek ptr = do
    dimMin <- peekByteOff ptr 0
    dimExtent <- peekByteOff ptr 4
    dimStride <- peekByteOff ptr 8
    dimFlags <- peekByteOff ptr 12
    pure (HalideDimension dimMin dimExtent dimStride dimFlags)
  {-# INLINE peek #-}
  poke ptr (HalideDimension dimMin dimExtent dimStride dimFlags) = do
    pokeByteOff ptr 0 dimMin
    pokeByteOff ptr 4 dimExtent
    pokeByteOff ptr 8 dimStride
    pokeByteOff ptr 12 dimFlags
  {-# INLINE poke #-}

-- | @simpleDimension extent stride@ describes a dimension of length @extent@ whose
-- consecutive elements are @stride@ elements apart. The starting index and the flags
-- are both zero.
--
-- Both numbers are narrowed to 'Int32' with 'fromIntegral', so values outside the
-- 'Int32' range wrap around silently.
simpleDimension :: Int -> Int -> HalideDimension
simpleDimension extent stride =
  HalideDimension
    { halideDimensionMin = 0
    , halideDimensionExtent = fromIntegral extent
    , halideDimensionStride = fromIntegral stride
    , halideDimensionFlags = 0
    }
{-# INLINE simpleDimension #-}

-- | Get strides corresponding to row-major ordering.
--
-- The last dimension is contiguous (stride 1). Each other stride is the product of
-- all extents that follow it, e.g. @rowMajorStrides [2, 3, 4] == [12, 4, 1]@.
rowMajorStrides
  :: Integral a
  => [a]
  -- ^ Extents
  -> [a]
rowMajorStrides = snd . foldr step (1, [])
  where
    -- Walk from the right, carrying the product of the extents seen so far; that
    -- product is the stride of the current dimension.
    step extent ~(acc, strides) = (extent * acc, acc : strides)

-- | Get strides corresponding to column-major ordering.
--
-- The first dimension is contiguous (stride 1). Each following stride is the product
-- of all preceding extents, e.g. @colMajorStrides [2, 3, 4] == [1, 2, 6]@.
-- An empty list of extents (a zero-dimensional buffer) yields no strides.
colMajorStrides
  :: Integral a
  => [a]
  -- ^ Extents
  -> [a]
-- scanl produces one running product more than there are extents; zipWith trims it
-- to exactly one stride per extent. Unlike dropping the last extent with 'init',
-- this is total, so a zero-dimensional shape no longer crashes.
colMajorStrides extents = zipWith const (scanl (*) 1 extents) extents

-- | Haskell analogue of [@halide_device_interface_t@](https://halide-lang.org/docs/structhalide__device__interface__t.html).
--
-- This type has no constructors. In this module it only appears behind a 'Ptr'
-- (e.g. 'halideBufferDeviceInterface'); its contents are accessed from C, not from Haskell.
data HalideDeviceInterface

-- | The low-level untyped Haskell analogue of [@halide_buffer_t@](https://halide-lang.org/docs/structhalide__buffer__t.html).
--
-- It's quite difficult to use 'RawHalideBuffer' correctly, and misuse can result in crashes and
-- segmentation faults. Hence, prefer the higher-level 'HalideBuffer' wrapper for all your code.
data RawHalideBuffer = RawHalideBuffer
  { halideBufferDevice :: !Word64
  -- ^ Device handle; a nonzero value means the buffer references device memory.
  , halideBufferDeviceInterface :: !(Ptr HalideDeviceInterface)
  -- ^ Interface that manages the device memory, or 'nullPtr' if there is none.
  , halideBufferHost :: !(Ptr Word8)
  -- ^ Pointer to the data in host (CPU) memory.
  , halideBufferFlags :: !Word64
  -- ^ Buffer flags (see @halide_buffer_flags@ in Halide).
  , halideBufferType :: !HalideType
  -- ^ Type of the elements.
  , halideBufferDimensions :: !Int32
  -- ^ Number of dimensions, i.e. the length of the 'halideBufferDim' array.
  , halideBufferDim :: !(Ptr HalideDimension)
  -- ^ Pointer to an array of per-dimension descriptions.
  , halideBufferPadding :: !(Ptr ())
  -- ^ Corresponds to the @padding@ field of @halide_buffer_t@.
  }
  deriving stock (Eq, Show)

-- | An @n@-dimensional buffer of elements of type @a@.
--
-- Both type parameters are phantom: at runtime this is just a 'RawHalideBuffer'.
-- Most pipelines use @'Ptr' ('HalideBuffer' n a)@ for input and output array arguments.
newtype HalideBuffer (n :: Nat) (a :: Type) = HalideBuffer
  { unHalideBuffer :: RawHalideBuffer
  -- ^ The underlying untyped buffer.
  }
  deriving stock (Eq, Show)

-- Top-level Template Haskell declaration splice, brought in by one of the
-- Language.Halide.* imports above.
-- NOTE(review): presumably it sets up the inline-c context used by the CU.exp /
-- CU.block quasi-quotes below (C++ includes and the mapping of C types such as
-- halide_buffer_t to the Haskell types declared above) — confirm against its definition.
importHalide

-- | Marshals a 'RawHalideBuffer' with the field offsets of @halide_buffer_t@:
-- 56 bytes in total, 8-byte aligned.
instance Storable RawHalideBuffer where
  sizeOf _ = 56
  alignment _ = 8
  peek ptr = do
    device <- peekByteOff ptr 0
    deviceInterface <- peekByteOff ptr 8
    host <- peekByteOff ptr 16
    flags <- peekByteOff ptr 24
    -- halide_type_t occupies 4 bytes, so the dimension count follows at 36
    elemType <- peekByteOff ptr 32
    dimensions <- peekByteOff ptr 36
    dims <- peekByteOff ptr 40
    padding <- peekByteOff ptr 48
    pure
      RawHalideBuffer
        { halideBufferDevice = device
        , halideBufferDeviceInterface = deviceInterface
        , halideBufferHost = host
        , halideBufferFlags = flags
        , halideBufferType = elemType
        , halideBufferDimensions = dimensions
        , halideBufferDim = dims
        , halideBufferPadding = padding
        }
  poke ptr buffer = do
    pokeByteOff ptr 0 (halideBufferDevice buffer)
    pokeByteOff ptr 8 (halideBufferDeviceInterface buffer)
    pokeByteOff ptr 16 (halideBufferHost buffer)
    pokeByteOff ptr 24 (halideBufferFlags buffer)
    pokeByteOff ptr 32 (halideBufferType buffer)
    pokeByteOff ptr 36 (halideBufferDimensions buffer)
    pokeByteOff ptr 40 (halideBufferDim buffer)
    pokeByteOff ptr 48 (halideBufferPadding buffer)

-- | Construct a 'HalideBuffer' from a pointer to the data, a list of extents,
-- and a list of strides, and use it in an 'IO' action.
--
-- The buffer description and its dimension array are allocated with 'with' and
-- 'withArrayLen'. They are valid only while the action runs, so the pointer passed to
-- the action must not escape it. The data itself is not copied.
--
-- This function throws a runtime error if:
--
--   * the lists of extents and strides have different lengths,
--   * the number of dimensions does not match @n@, or
--   * after the action returns, the buffer still references data on a device.
bufferFromPtrShapeStrides
  :: forall n a b
   . (HasCallStack, KnownNat n, IsHalideType a)
  => Ptr a
  -- ^ CPU pointer to the data
  -> [Int]
  -- ^ Extents (in number of elements, __not__ in bytes)
  -> [Int]
  -- ^ Strides (in number of elements, __not__ in bytes)
  -> (Ptr (HalideBuffer n a) -> IO b)
  -- ^ Action to run
  -> IO b
bufferFromPtrShapeStrides p shape stride action = do
  -- zipWith below would silently drop the surplus entries of the longer list,
  -- producing a buffer whose rank differs from what the caller specified.
  unless (length shape == length stride) $
    error $
      "extents and strides have different lengths: "
        <> show (length shape)
        <> " vs "
        <> show (length stride)
  withArrayLen (zipWith simpleDimension shape stride) $ \n dim -> do
    unless (n == fromIntegral (natVal (Proxy @n))) $
      error $
        "specified wrong number of dimensions: "
          <> show n
          <> "; expected "
          <> show (natVal (Proxy @n))
          <> " from the type declaration"
    let !buffer =
          RawHalideBuffer
            { halideBufferDevice = 0
            , halideBufferDeviceInterface = nullPtr
            , halideBufferHost = castPtr p
            , halideBufferFlags = 0
            , halideBufferType = halideTypeFor (Proxy :: Proxy a)
            , halideBufferDimensions = fromIntegral n
            , halideBufferDim = dim
            , halideBufferPadding = nullPtr
            }
    with buffer $ \bufferPtr -> do
      r <- action (castPtr bufferPtr)
      -- The buffer lives on the Haskell side only for the duration of this call, so any
      -- device allocation the action left behind could never be released.
      hasDataOnDevice <-
        toEnum . fromIntegral
          <$> [CU.exp| bool { $(halide_buffer_t* bufferPtr)->device } |]
      when hasDataOnDevice $
        error "the Buffer still references data on the device; did you forget to call copyToHost?"
      pure r

-- | Similar to 'bufferFromPtrShapeStrides', but assumes column-major ordering of data:
-- the strides are derived from the extents with 'colMajorStrides'.
bufferFromPtrShape
  :: (HasCallStack, KnownNat n, IsHalideType a)
  => Ptr a
  -- ^ CPU pointer to the data
  -> [Int]
  -- ^ Extents (in number of elements, __not__ in bytes)
  -> (Ptr (HalideBuffer n a) -> IO b)
  -- ^ Action to run
  -> IO b
bufferFromPtrShape ptr extents action =
  let strides = colMajorStrides extents
   in bufferFromPtrShapeStrides ptr extents strides action

-- | Specifies that a type @t@ can be used as an @n@-dimensional Halide buffer with elements of type @a@.
--
-- Instances only have to implement 'withHalideBufferImpl'. User code should normally call
-- 'withHalideBuffer' instead, whose type parameters are ordered for use with @TypeApplications@.
class (KnownNat n, IsHalideType a) => IsHalideBuffer t n a where
  -- | Present a value of type @t@ as a 'HalideBuffer' to an 'IO' action. In every
  -- instance in this module, the pointer is only valid inside the action.
  withHalideBufferImpl
    :: t
    -- ^ Value to view as a buffer
    -> (Ptr (HalideBuffer n a) -> IO b)
    -- ^ Action to run
    -> IO b

-- | Treat a type @t@ as a 'HalideBuffer' and use it in an 'IO' action.
--
-- This is 'withHalideBufferImpl' with the type parameters reordered so that the
-- dimension and element type come first. With the @TypeApplications@ extension you can
-- write @withHalideBuffer \@3 \@Float yourBuffer@ to request a 3-dimensional buffer of @Float@.
withHalideBuffer :: forall n a t b. IsHalideBuffer t n a => t -> (Ptr (HalideBuffer n a) -> IO b) -> IO b
withHalideBuffer buffer action = withHalideBufferImpl @t @n @a buffer action

-- | Storable vectors are one-dimensional buffers. This involves no copying: the buffer
-- points directly at the vector's storage.
instance IsHalideType a => IsHalideBuffer (S.Vector a) 1 a where
  withHalideBufferImpl vec action = do
    let extents = [S.length vec]
    S.unsafeWith vec $ \ptr -> bufferFromPtrShape ptr extents action

-- | Mutable storable vectors (living in 'IO') are one-dimensional buffers. This involves
-- no copying: the buffer points directly at the vector's storage, so a pipeline can
-- write its output into it.
instance IsHalideType a => IsHalideBuffer (S.MVector RealWorld a) 1 a where
  withHalideBufferImpl vec action = do
    let extents = [SM.length vec]
    SM.unsafeWith vec $ \ptr -> bufferFromPtrShape ptr extents action

-- | Lists can also act as Halide buffers. __Use for testing only.__
--
-- The list is first copied into a storable vector, which is then used as the buffer.
instance IsHalideType a => IsHalideBuffer [a] 1 a where
  withHalideBufferImpl xs action = withHalideBufferImpl (S.fromList xs) action

-- | Lists can also act as Halide buffers. __Use for testing only.__
--
-- The outer list is dimension 0 and the inner lists are dimension 1. The elements are
-- copied into a column-major storable vector first.
--
-- Throws a runtime error if the inner lists do not all have the same length.
instance IsHalideType a => IsHalideBuffer [[a]] 2 a where
  withHalideBufferImpl xs f = do
    let d0 = length xs
        d1 = case xs of
          (row : _) -> length row
          [] -> 0
        -- we want column-major ordering, so transpose first
        v = S.fromList (List.concat (List.transpose xs))
    -- Comparing only the total element count is not enough: rows of lengths 2, 1, 3
    -- hold 3 * 2 elements, yet transposing them silently scrambles the layout.
    unless (all ((== d1) . length) xs) $
      error "list doesn't have a regular shape (i.e. rows have varying number of elements)"
    S.unsafeWith v $ \cpuPtr ->
      bufferFromPtrShape cpuPtr [d0, d1] f

-- | Lists can also act as Halide buffers. __Use for testing only.__
--
-- The outermost list is dimension 0 and the innermost lists are dimension 2. The
-- elements are copied into a column-major storable vector first.
--
-- Throws a runtime error if the nested lists do not form a regular
-- @d0 × d1 × d2@ block.
instance IsHalideType a => IsHalideBuffer [[[a]]] 3 a where
  withHalideBufferImpl xs f = do
    let d0 = length xs
        d1 = case xs of
          (m : _) -> length m
          [] -> 0
        d2 = case xs of
          ((row : _) : _) -> length row
          _ -> 0
        -- Every matrix must have d1 rows and every row d2 elements. A total-count check
        -- alone would accept ragged input whose transposition scrambles the layout.
        isRegular = all (\m -> length m == d1 && all ((== d2) . length) m) xs
        -- we want column-major ordering, so transpose first
        v =
          S.fromList
            . List.concat
            . List.concatMap List.transpose
            . List.transpose
            . fmap List.transpose
            $ xs
    unless isRegular $
      error "list doesn't have a regular shape (i.e. rows have varying number of elements)"
    S.unsafeWith v $ \cpuPtr ->
      bufferFromPtrShape cpuPtr [d0, d1, d2] f

-- | Lists can also act as Halide buffers. __Use for testing only.__
--
-- The outermost list is dimension 0 and the innermost lists are dimension 3. The
-- elements are copied into a column-major storable vector first.
--
-- Throws a runtime error if the nested lists do not form a regular
-- @d0 × d1 × d2 × d3@ block.
instance IsHalideType a => IsHalideBuffer [[[[a]]]] 4 a where
  withHalideBufferImpl xs f = do
    let d0 = length xs
        d1 = case xs of
          (c : _) -> length c
          [] -> 0
        d2 = case xs of
          ((m : _) : _) -> length m
          _ -> 0
        d3 = case xs of
          (((row : _) : _) : _) -> length row
          _ -> 0
        -- Every level of nesting must have a uniform length. A total-count check alone
        -- would accept ragged input whose transposition scrambles the layout.
        isRegular =
          all
            ( \c ->
                length c == d1
                  && all (\m -> length m == d2 && all ((== d3) . length) m) c
            )
            xs
        -- we want column-major ordering, so transpose first
        v =
          S.fromList
            . concat
            . concat
            . concatMap (fmap List.transpose . List.transpose . fmap List.transpose)
            . List.transpose
            . fmap (List.transpose . fmap List.transpose)
            $ xs
    unless isRegular $
      error "list doesn't have a regular shape (i.e. rows have varying number of elements)"
    S.unsafeWith v $ \cpuPtr ->
      bufferFromPtrShape cpuPtr [d0, d1, d2, d3] f

-- | Run the condition, then run the action only if the condition produced 'True'.
whenM :: Monad m => m Bool -> m () -> m ()
whenM cond action = do
  ok <- cond
  if ok then action else pure ()

-- | Temporarily allocate a CPU buffer: 'allocaBuffer' specialised to 'hostTarget'.
--
-- This is useful for testing and debugging when you need to allocate an output buffer for your pipeline. E.g.
--
-- @
-- 'allocaCpuBuffer' [3, 3] $ \out -> do
--   myKernel out                -- fill the buffer
--   print =<< 'peekToList' out  -- print it for debugging
-- @
allocaCpuBuffer
  :: forall n a b
   . (HasCallStack, KnownNat n, IsHalideType a)
  => [Int]
  -- ^ Extents (in number of elements, __not__ in bytes)
  -> (Ptr (HalideBuffer n a) -> IO b)
  -- ^ Action to run
  -> IO b
allocaCpuBuffer extents action = allocaBuffer hostTarget extents action

-- | Number of bytes needed to store a buffer's elements densely.
--
-- The result is the product of all extents times the size of one element. Strides are
-- ignored, so this is only right for densely packed buffers (e.g. ones whose strides
-- come from 'colMajorStrides').
--
-- The element size rounds the bit width up to whole bytes, as Halide's
-- @halide_type_t::bytes()@ does. Without this, a 1-bit type such as a boolean would
-- give @bits / 8 == 0@ and lead to a zero-sized allocation.
getTotalBytes :: Ptr RawHalideBuffer -> IO Int
getTotalBytes buf =
  fromIntegral
    <$> [CU.block| size_t {
          auto const& b = *$(const halide_buffer_t* buf);
          auto const n = std::accumulate(b.dim, b.dim + b.dimensions, size_t{1},
                                         [](auto acc, auto const& dim) { return acc * dim.extent; });
          auto const bytes_per_element = static_cast<size_t>((b.type.bits + 7) / 8) * b.type.lanes;
          return n * bytes_per_element;
        } |]

-- | Allocate host memory for a buffer and store the pointer in its @host@ field.
--
-- The size comes from 'getTotalBytes' and the memory is obtained with 'mallocBytes',
-- so it must be released with 'free' (see 'freeHostMemory'). Any previous @host@
-- pointer is overwritten without being freed.
allocateHostMemory :: Ptr RawHalideBuffer -> IO ()
allocateHostMemory buf = do
  numBytes <- getTotalBytes buf
  hostPtr <- mallocBytes numBytes
  [CU.block| void { $(halide_buffer_t* buf)->host = $(uint8_t* hostPtr); } |]

-- | Detach a buffer's host memory and release it with 'free'.
--
-- The @host@ field is reset to @nullptr@ before the old pointer is freed, so the
-- buffer never holds a dangling host pointer.
freeHostMemory :: Ptr RawHalideBuffer -> IO ()
freeHostMemory buf =
  free
    =<< [CU.block| uint8_t* {
          auto& b = *$(halide_buffer_t* buf);
          auto* const old_host = b.host;
          b.host = nullptr;
          return old_host;
        } |]

-- | Allocate device memory for a buffer through the given device interface.
--
-- Calls the interface's @device_malloc@ with a null user context. Halide reports failure
-- as a nonzero return code. In that case this function throws a runtime error, rather
-- than letting later code use a buffer with no device allocation.
allocateDeviceMemory :: HasCallStack => Ptr HalideDeviceInterface -> Ptr RawHalideBuffer -> IO ()
allocateDeviceMemory interface buf = do
  status <-
    [CU.block| int {
      auto const* interface = $(const halide_device_interface_t* interface);
      return interface->device_malloc(nullptr, $(halide_buffer_t* buf), interface);
    } |]
  unless (status == 0) $
    error $ "device_malloc failed with error code " <> show status

-- | Release the buffer's device memory through its own @device_interface@ and reset its
-- @device@ field to 0.
--
-- Calls 'error' if the buffer has no device interface, or if @device_free@ reports a
-- non-zero status.
freeDeviceMemory :: HasCallStack => Ptr RawHalideBuffer -> IO ()
freeDeviceMemory buf = do
  deviceInterface <-
    [CU.exp| const halide_device_interface_t* { $(const halide_buffer_t* buf)->device_interface } |]
  when (deviceInterface == nullPtr) $
    error "cannot free device memory: device_interface is NULL"
  status <-
    [CU.block| int {
      auto& b = *$(halide_buffer_t* buf);
      auto const status = b.device_interface->device_free(nullptr, &b);
      b.device = 0;
      return status;
    } |]
  unless (status == 0) . error $
    "failed to free device memory: device_free returned error code " <> show status

-- | Temporarily allocate a buffer of the given shape for the given target and run an action on it.
--
-- If 'getDeviceInterface' returns no interface for the target, the memory is allocated on the
-- host (with 'allocateHostMemory'); otherwise it is allocated on the device (with
-- 'allocateDeviceMemory'). Strides are computed by 'colMajorStrides'. The memory is released
-- when the action finishes, also if it throws.
--
-- Calls 'error' if the length of the shape does not match the type-level rank @n@, if a host
-- buffer ends up with a device allocation, or if a device buffer ends up with host memory.
allocaBuffer
  :: forall n a b
   . (HasCallStack, KnownNat n, IsHalideType a)
  => Target
  -> [Int]
  -> (Ptr (HalideBuffer n a) -> IO b)
  -> IO b
allocaBuffer target shape action = do
  deviceInterface <- getDeviceInterface target
  let onHost = deviceInterface == nullPtr
      acquire
        | onHost = allocateHostMemory
        | otherwise = allocateDeviceMemory deviceInterface
      release
        | onHost = freeHostMemory
        | otherwise = freeDeviceMemory
      expectedRank = natVal (Proxy @n)
      dims = zipWith simpleDimension shape (colMajorStrides shape)
  withArrayLen dims $ \rank dimPtr -> do
    unless (rank == fromIntegral expectedRank) . error $
      "specified wrong number of dimensions: "
        <> show rank
        <> "; expected "
        <> show expectedRank
        <> " from the type declaration"
    let rawBuffer =
          RawHalideBuffer
            { halideBufferDevice = 0
            , halideBufferDeviceInterface = nullPtr
            , halideBufferHost = nullPtr
            , halideBufferFlags = 0
            , halideBufferType = halideTypeFor (Proxy @a)
            , halideBufferDimensions = fromIntegral rank
            , halideBufferDim = dimPtr
            , halideBufferPadding = nullPtr
            }
    with rawBuffer $ \buf ->
      bracket_ (acquire buf) (release buf) $ do
        result <- action (castPtr buf)
        hostIsNull <- toBool <$> [CU.exp| bool { $(halide_buffer_t* buf)->host == nullptr } |]
        deviceIsNull <- toBool <$> [CU.exp| bool { $(halide_buffer_t* buf)->device == 0 } |]
        -- The pipeline must leave the memory where it was allocated.
        when (onHost && not deviceIsNull) . error $
          "buffer was allocated on host, but its device pointer is not NULL"
            <> "; did you forget a copyToHost in your pipeline?"
        when (not onHost && not hostIsNull) . error $
          "buffer was allocated on device, but its host pointer is not NULL"
            <> "; did you add an extra copyToHost?"
        pure result

-- | Look up the Halide device interface for the device API of the given target.
--
-- Targets whose device API is 'DeviceNone' or 'DeviceHost' have no device interface, and
-- 'nullPtr' is returned. For any other device API, the interface is obtained from
-- @Halide::get_device_interface_for_device_api@ inside @handle_halide_exceptions@.
getDeviceInterface :: Target -> IO (Ptr HalideDeviceInterface)
getDeviceInterface target = case deviceAPIForTarget target of
  DeviceNone -> pure nullPtr
  DeviceHost -> pure nullPtr
  device -> lookupInterface (fromIntegral (fromEnum device))
  where
    lookupInterface api =
      withCxxTarget target $ \target' ->
        [C.throwBlock| const halide_device_interface_t* {
          return handle_halide_exceptions([=](){
            auto const device = static_cast<Halide::DeviceAPI>($(int api));
            auto const& target = *$(const Halide::Target* target');
            return Halide::get_device_interface_for_device_api(device, target, "getDeviceInterface");
          });
        } |]

-- | Do we have changes on the device that have not been copied to the host?
--
-- Reads the buffer's @device_dirty@ flag.
isDeviceDirty :: Ptr RawHalideBuffer -> IO Bool
isDeviceDirty p = do
  flag <- [CU.exp| bool { $(const halide_buffer_t* p)->device_dirty() } |]
  pure (toBool flag)

-- | Overwrite the buffer's @device_dirty@ flag with the given value.
setDeviceDirty :: Bool -> Ptr RawHalideBuffer -> IO ()
setDeviceDirty dirty p =
  [CU.exp| void { $(halide_buffer_t* p)->set_device_dirty($(bool b)) } |]
  where
    b = if dirty then 1 else 0

-- | Do we have changes on the host that have not been copied to the device?
--
-- Reads the buffer's @host_dirty@ flag.
isHostDirty :: Ptr RawHalideBuffer -> IO Bool
isHostDirty p =
  toBool <$> [CU.exp| bool { $(const halide_buffer_t* p)->host_dirty() } |]

-- | Overwrite the buffer's @host_dirty@ flag with the given value.
setHostDirty :: Bool -> Ptr RawHalideBuffer -> IO ()
setHostDirty dirty p =
  [CU.exp| void { $(halide_buffer_t* p)->set_host_dirty($(bool b)) } |]
  where
    b = if dirty then 1 else 0

-- | Copy the underlying memory from device to host.
--
-- Nothing happens unless the buffer's @device_dirty@ flag is set. Otherwise, the buffer must
-- have a device interface and allocated host memory. The copy is done by the interface's
-- @copy_to_host@ function, and afterwards @device_dirty@ must be cleared.
--
-- Calls 'error' if a precondition is violated, if @copy_to_host@ returns a non-zero status, or
-- if @device_dirty@ is still set after the copy.
bufferCopyToHost :: HasCallStack => Ptr RawHalideBuffer -> IO ()
bufferCopyToHost p = whenM (isDeviceDirty p) $ do
  raw <- peek p
  when (raw.halideBufferDeviceInterface == nullPtr) . error $
    "device_dirty is set, but device_interface is NULL"
  when (raw.halideBufferHost == nullPtr) . error $
    "host is NULL, did you forget to allocate memory?"
  status <-
    [CU.block| int {
      auto& buf = *$(halide_buffer_t* p);
      return buf.device_interface->copy_to_host(nullptr, &buf);
    } |]
  -- Report the runtime's own status first; it is more informative than the flag check below.
  unless (status == 0) . error $
    "copy_to_host failed with error code " <> show status
  whenM (isDeviceDirty p) . error $
    "device_dirty is set right after a copy_to_host; something went wrong..."

-- | Verify that the runtime rank of a buffer (its @dimensions@ field) equals the type-level
-- rank @n@. Calls 'error' with both numbers if they differ.
checkNumberOfDimensions :: forall n. (HasCallStack, KnownNat n) => RawHalideBuffer -> IO ()
checkNumberOfDimensions raw
  | expected == actual = pure ()
  | otherwise =
      error $
        "type-level and runtime number of dimensions do not match: "
          <> show (natVal (Proxy @n))
          <> " != "
          <> show actual
  where
    actual = raw.halideBufferDimensions
    expected = fromIntegral (natVal (Proxy @n))

-- | Perform an action on a cropped buffer.
--
-- The cropped buffer is a view that shares memory with the original. Only dimension @d@ is
-- restricted to the coordinates starting at @min@ and spanning @extent@ elements, and the host
-- pointer is shifted accordingly. If the original buffer has a device allocation, the view is
-- created with the device interface's @device_crop@. It is released with
-- @device_release_crop@ when the action finishes, also if the action throws.
--
-- Calls 'error' if @d@ is not a valid dimension index, or if @device_crop@ or
-- @device_release_crop@ return a non-zero status.
withCropped
  :: Ptr (HalideBuffer n a)
  -- ^ buffer
  -> Int
  -- ^ dimension
  -> Int
  -- ^ min
  -> Int
  -- ^ extent
  -> (Ptr (HalideBuffer n a) -> IO b)
  -- ^ what to do
  -> IO b
withCropped
  (castPtr -> src)
  d@(fromIntegral -> cropDim)
  (fromIntegral -> cropMin)
  (fromIntegral -> cropExtent)
  action = do
    rank <- fromIntegral <$> [CU.exp| int { $(const halide_buffer_t* src)->dimensions } |]
    -- Without this check, dim[d] below would be accessed out of bounds.
    when (d < 0 || d >= rank) . error $
      "withCropped: dimension " <> show d <> " is out of bounds for a buffer of rank " <> show rank
    alloca $ \dst ->
      allocaArray rank $ \dstDim ->
        bracket_ (crop dst dstDim) (releaseCrop dst) (action (castPtr dst))
  where
    -- Fill in dst as a cropped view of src (using dstDim as storage for its dimensions).
    crop dst dstDim = do
      status <-
        [CU.block| int {
          auto const& src = *$(const halide_buffer_t* src);
          auto& dst = *$(halide_buffer_t* dst);
          auto const d = $(int cropDim);

          dst = src;
          dst.dim = $(halide_dimension_t* dstDim);
          memcpy(dst.dim, src.dim, src.dimensions * sizeof(halide_dimension_t));
          // Do not inherit the device handle: device_crop rejects a destination that already
          // has a device allocation, and it sets dst.device itself on success.
          dst.device = 0;

          if (dst.host != nullptr) {
            auto const shift = $(int cropMin) - src.dim[d].min;
            dst.host += (shift * src.dim[d].stride) * ((src.type.bits + 7) / 8);
          }
          dst.dim[d].min = $(int cropMin);
          dst.dim[d].extent = $(int cropExtent);

          if (src.device != 0 && src.device_interface != nullptr) {
            return src.device_interface->device_crop(nullptr, &src, &dst);
          }
          return 0;
        } |]
      unless (status == 0) . error $
        "withCropped: device_crop failed with error code " <> show status
    -- Release the device-side view created by crop (a no-op for host-only buffers).
    releaseCrop dst = do
      status <-
        [CU.block| int {
          auto& dst = *$(halide_buffer_t* dst);
          if (dst.device != 0 && dst.device_interface != nullptr) {
            return dst.device_interface->device_release_crop(nullptr, &dst);
          }
          return 0;
        } |]
      unless (status == 0) . error $
        "withCropped: device_release_crop failed with error code " <> show status

-- | Get the extent of the buffer along dimension @d@.
--
-- Calls 'error' ("index out of bounds") unless @0 <= d < n@, where @n@ is the type-level rank.
-- Negative indices are rejected, too; they would otherwise read before the start of the
-- @dim@ array.
getBufferExtent :: forall n a. KnownNat n => Ptr (HalideBuffer n a) -> Int -> IO Int
getBufferExtent (castPtr -> buf) d@(fromIntegral -> cd)
  | d >= 0 && toInteger d < natVal (Proxy @n) =
      fromIntegral <$> [CU.exp| int { $(const halide_buffer_t* buf)->dim[$(int cd)].extent } |]
  | otherwise = error "index out of bounds"

-- | Read the single element of a zero-dimensional buffer.
--
-- The data is made available on the host with 'withCopiedToHost' first. Calls 'error' if the
-- buffer's runtime rank is not 0 or if its host pointer is @NULL@.
peekScalar :: forall a. (HasCallStack, IsHalideType a) => Ptr (HalideBuffer 0 a) -> IO a
peekScalar p = withCopiedToHost p $ do
  raw <- peek (castPtr @_ @RawHalideBuffer p)
  checkNumberOfDimensions @0 raw
  let host = raw.halideBufferHost
  if host == nullPtr
    then error "host is NULL"
    else peek (castPtr @_ @a host)

-- Commented-out single-parameter variant of 'IsListPeek' (presumably an earlier design),
-- kept for reference:
--
-- class IsListPeek a where
--   type ListPeekElem a :: Type
--   peekToList :: HasCallStack => Ptr a -> IO [ListPeekElem a]

-- | @NestedList n a@ is @a@ wrapped in @n@ levels of lists, e.g.
-- @NestedList 2 Float = [[Float]]@ and @NestedList 0 Float = Float@. Only ranks 0 to 5 reduce.
type family NestedList (n :: Nat) (a :: Type) where
  NestedList 0 a = a
  NestedList 1 a = [a]
  NestedList 2 a = [[a]]
  NestedList 3 a = [[[a]]]
  NestedList 4 a = [[[[a]]]]
  NestedList 5 a = [[[[[a]]]]]

-- | Nesting depth of a list type: @NestedListLevel [[Int]] = 2@; any non-list type has depth 0.
type family NestedListLevel (x :: Type) :: Nat where
  NestedListLevel [x] = 1 + NestedListLevel x
  NestedListLevel _ = 0

-- | Innermost element type of a (possibly nested) list type: @NestedListType [[Int]] = Int@;
-- any non-list type maps to itself.
type family NestedListType (x :: Type) :: Type where
  NestedListType [e] = NestedListType e
  NestedListType e = e

-- | Specifies that a buffer of rank @n@ with elements of type @a@ can be read into the nested
-- list type @b@. This is similar to 'GHC.Exts.IsList', except that the list is read from a
-- @'Ptr'@ to a 'HalideBuffer' rather than converted directly.
--
-- The functional dependencies let any two of @n@, @a@, and @b@ determine the third.
class
  ( IsHalideType a
  , KnownNat n
  , NestedList n a ~ b
  , NestedListType b ~ a
  , NestedListLevel b ~ n
  ) =>
  IsListPeek n a b
    | n a -> b
    , a b -> n
    , n b -> a
  where
  -- | Read the contents of the buffer into a nested list.
  peekToList :: HasCallStack => Ptr (HalideBuffer n a) -> IO b

instance
  (IsHalideType a, NestedListLevel [a] ~ 1, NestedListType [a] ~ a)
  => IsListPeek 1 a [a]
  where
  -- Read a one-dimensional buffer into a list.
  --
  -- The host pointer addresses the element at coordinate dim[0].min (withCropped relies on the
  -- same convention when it shifts host). The i-th element therefore lives at offset
  -- stride * i, and min must not be added again. Offsets are computed in Int to avoid Int32
  -- overflow for large buffers.
  peekToList :: HasCallStack => Ptr (HalideBuffer 1 a) -> IO [a]
  peekToList p = withCopiedToHost p $ do
    raw <- peek (castPtr @_ @RawHalideBuffer p)
    checkNumberOfDimensions @1 raw
    (HalideDimension _ extent0 stride0 _) <- peekElemOff (halideBufferDim raw) 0
    let ptr0 = castPtr @_ @a (halideBufferHost raw)
    when (ptr0 == nullPtr) . error $ "host is NULL"
    forM [0 .. extent0 - 1] $ \i0 ->
      peekElemOff ptr0 (fromIntegral stride0 * fromIntegral i0)

instance
  (IsHalideType a, NestedListLevel [[a]] ~ 2, NestedListType [[a]] ~ a)
  => IsListPeek 2 a [[a]]
  where
  -- Read a two-dimensional buffer into a list of lists. The outer list runs over dimension 0
  -- and the inner lists over dimension 1.
  --
  -- The host pointer addresses the element at coordinates (dim[0].min, dim[1].min) (withCropped
  -- relies on the same convention when it shifts host). Element (min0 + i0, min1 + i1)
  -- therefore lives at offset stride0 * i0 + stride1 * i1, and the mins must not be added
  -- again. Offsets are computed in Int to avoid Int32 overflow for large buffers.
  peekToList :: HasCallStack => Ptr (HalideBuffer 2 a) -> IO [[a]]
  peekToList p = withCopiedToHost p $ do
    raw <- peek (castPtr @_ @RawHalideBuffer p)
    checkNumberOfDimensions @2 raw
    (HalideDimension _ extent0 stride0 _) <- peekElemOff (halideBufferDim raw) 0
    (HalideDimension _ extent1 stride1 _) <- peekElemOff (halideBufferDim raw) 1
    let ptr0 = castPtr @_ @a (halideBufferHost raw)
    when (ptr0 == nullPtr) . error $ "host is NULL"
    forM [0 .. extent0 - 1] $ \i0 -> do
      let ptr1 = ptr0 `advancePtr` (fromIntegral stride0 * fromIntegral i0)
      forM [0 .. extent1 - 1] $ \i1 ->
        peekElemOff ptr1 (fromIntegral stride1 * fromIntegral i1)

instance
  (IsHalideType a, NestedListLevel [[[a]]] ~ 3, NestedListType [[[a]]] ~ a)
  => IsListPeek 3 a [[[a]]]
  where
  -- Read a three-dimensional buffer into nested lists. The outermost list runs over
  -- dimension 0 and the innermost over dimension 2.
  --
  -- The host pointer addresses the element whose coordinates are the dims' mins (withCropped
  -- relies on the same convention when it shifts host). Element (min0 + i0, min1 + i1,
  -- min2 + i2) therefore lives at offset stride0 * i0 + stride1 * i1 + stride2 * i2, and the
  -- mins must not be added again. Offsets are computed in Int to avoid Int32 overflow.
  peekToList :: HasCallStack => Ptr (HalideBuffer 3 a) -> IO [[[a]]]
  peekToList p = withCopiedToHost p $ do
    raw <- peek (castPtr @_ @RawHalideBuffer p)
    checkNumberOfDimensions @3 raw
    (HalideDimension _ extent0 stride0 _) <- peekElemOff (halideBufferDim raw) 0
    (HalideDimension _ extent1 stride1 _) <- peekElemOff (halideBufferDim raw) 1
    (HalideDimension _ extent2 stride2 _) <- peekElemOff (halideBufferDim raw) 2
    let ptr0 = castPtr @_ @a (halideBufferHost raw)
    when (ptr0 == nullPtr) . error $ "host is NULL"
    forM [0 .. extent0 - 1] $ \i0 -> do
      let ptr1 = ptr0 `advancePtr` (fromIntegral stride0 * fromIntegral i0)
      forM [0 .. extent1 - 1] $ \i1 -> do
        let ptr2 = ptr1 `advancePtr` (fromIntegral stride1 * fromIntegral i1)
        forM [0 .. extent2 - 1] $ \i2 ->
          peekElemOff ptr2 (fromIntegral stride2 * fromIntegral i2)

instance
  (IsHalideType a, NestedListLevel [[[[a]]]] ~ 4, NestedListType [[[[a]]]] ~ a)
  => IsListPeek 4 a [[[[a]]]]
  where
  -- Read a four-dimensional buffer into nested lists. The outermost list runs over
  -- dimension 0 and the innermost over dimension 3.
  --
  -- The host pointer addresses the element whose coordinates are the dims' mins (withCropped
  -- relies on the same convention when it shifts host). The element at offsets (i0, i1, i2, i3)
  -- from the mins therefore lives at stride0 * i0 + stride1 * i1 + stride2 * i2 + stride3 * i3,
  -- and the mins must not be added again. Offsets are computed in Int to avoid Int32 overflow.
  peekToList :: HasCallStack => Ptr (HalideBuffer 4 a) -> IO [[[[a]]]]
  peekToList p = withCopiedToHost p $ do
    raw <- peek (castPtr @_ @RawHalideBuffer p)
    checkNumberOfDimensions @4 raw
    (HalideDimension _ extent0 stride0 _) <- peekElemOff (halideBufferDim raw) 0
    (HalideDimension _ extent1 stride1 _) <- peekElemOff (halideBufferDim raw) 1
    (HalideDimension _ extent2 stride2 _) <- peekElemOff (halideBufferDim raw) 2
    (HalideDimension _ extent3 stride3 _) <- peekElemOff (halideBufferDim raw) 3
    let ptr0 = castPtr @_ @a (halideBufferHost raw)
    when (ptr0 == nullPtr) . error $ "host is NULL"
    forM [0 .. extent0 - 1] $ \i0 -> do
      let ptr1 = ptr0 `advancePtr` (fromIntegral stride0 * fromIntegral i0)
      forM [0 .. extent1 - 1] $ \i1 -> do
        let ptr2 = ptr1 `advancePtr` (fromIntegral stride1 * fromIntegral i1)
        forM [0 .. extent2 - 1] $ \i2 -> do
          let ptr3 = ptr2 `advancePtr` (fromIntegral stride2 * fromIntegral i2)
          forM [0 .. extent3 - 1] $ \i3 ->
            peekElemOff ptr3 (fromIntegral stride3 * fromIntegral i3)

-- | @withCopiedToHost buf action@ performs the action @action@ ensuring that @buf@ has been
-- copied to the host beforehand. If @buf@ has no device allocation, no copying is performed.
--
-- For a buffer that lives only on the device (its @host@ pointer is @NULL@), temporary host
-- memory is allocated, filled with the device data, and freed again after @action@. For a
-- buffer that has both a device allocation and host memory, the existing host memory is kept.
-- It is only refreshed from the device if the device copy is dirty, and it is never freed
-- here.
withCopiedToHost :: Ptr (HalideBuffer n a) -> IO b -> IO b
withCopiedToHost (castPtr @_ @RawHalideBuffer -> buf) action = do
  raw <- peek buf
  let onDevice = raw.halideBufferDevice /= 0
      -- Replacing (and later freeing) an existing host pointer would destroy the caller's
      -- memory, so only device-only buffers get a temporary host allocation.
      needsHostMemory = onDevice && raw.halideBufferHost == nullPtr
      allocate = when needsHostMemory $ do
        allocateHostMemory buf
        -- Fresh host memory holds garbage, so force bufferCopyToHost to perform the copy.
        setDeviceDirty True buf
      deallocate = when needsHostMemory $ freeHostMemory buf
  bracket_ allocate deallocate $ do
    when onDevice $ bufferCopyToHost buf
    action

-- instance IsHalideType a => IsListPeek (HalideBuffer 0 a) where
--   type ListPeekElem (HalideBuffer 0 a) = a
--   peekToList p = withCopiedToHost p $ do
--     raw <- peek (castPtr @_ @RawHalideBuffer p)
--     checkNumberOfDimensions @0 raw
--     when (raw.halideBufferHost == nullPtr) . error $ "host is NULL"
--     fmap pure . peek $ castPtr @_ @a raw.halideBufferHost

-- instance IsHalideType a => IsListPeek (HalideBuffer 1 a) where
--   type ListPeekElem (HalideBuffer 1 a) = a
--   peekToList p = withCopiedToHost p $ do
--     raw <- peek (castPtr @_ @RawHalideBuffer p)
--     (HalideDimension min0 extent0 stride0 _) <- peekElemOff (halideBufferDim raw) 0
--     let ptr0 = castPtr @_ @a (halideBufferHost raw)
--     when (ptr0 == nullPtr) . error $ "host is NULL"
--     forM [0 .. extent0 - 1] $ \i0 ->
--       peekElemOff ptr0 (fromIntegral (min0 + stride0 * i0))

-- instance IsHalideType a => IsListPeek (HalideBuffer 2 a) where
--   type ListPeekElem (HalideBuffer 2 a) = [a]
--   peekToList p = withCopiedToHost p $ do
--     raw <- peek (castPtr @_ @RawHalideBuffer p)
--     (HalideDimension min0 extent0 stride0 _) <- peekElemOff (halideBufferDim raw) 0
--     (HalideDimension min1 extent1 stride1 _) <- peekElemOff (halideBufferDim raw) 1
--     let ptr0 = castPtr @_ @a (halideBufferHost raw)
--     when (ptr0 == nullPtr) . error $ "host is NULL"
--     forM [0 .. extent0 - 1] $ \i0 -> do
--       let ptr1 = ptr0 `advancePtr` fromIntegral (min0 + stride0 * i0)
--       forM [0 .. extent1 - 1] $ \i1 ->
--         peekElemOff ptr1 (fromIntegral (min1 + stride1 * i1))

-- instance IsHalideType a => IsListPeek (HalideBuffer 3 a) where
--   type ListPeekElem (HalideBuffer 3 a) = [[a]]
--   peekToList p = withCopiedToHost p $ do
--     raw <- peek (castPtr @_ @RawHalideBuffer p)
--     (HalideDimension min0 extent0 stride0 _) <- peekElemOff (halideBufferDim raw) 0
--     (HalideDimension min1 extent1 stride1 _) <- peekElemOff (halideBufferDim raw) 1
--     (HalideDimension min2 extent2 stride2 _) <- peekElemOff (halideBufferDim raw) 2
--     let ptr0 = castPtr @_ @a (halideBufferHost raw)
--     when (ptr0 == nullPtr) . error $ "host is NULL"
--     forM [0 .. extent0 - 1] $ \i0 -> do
--       let ptr1 = ptr0 `advancePtr` fromIntegral (min0 + stride0 * i0)
--       forM [0 .. extent1 - 1] $ \i1 -> do
--         let ptr2 = ptr1 `advancePtr` fromIntegral (min1 + stride1 * i1)
--         forM [0 .. extent2 - 1] $ \i2 ->
--           peekElemOff ptr2 (fromIntegral (min2 + stride2 * i2))

-- instance IsHalideType a => IsListPeek (HalideBuffer 4 a) where
--   type ListPeekElem (HalideBuffer 4 a) = [[[a]]]
--   peekToList p = withCopiedToHost p $ do
--     raw <- peek (castPtr @_ @RawHalideBuffer p)
--     (HalideDimension min0 extent0 stride0 _) <- peekElemOff (halideBufferDim raw) 0
--     (HalideDimension min1 extent1 stride1 _) <- peekElemOff (halideBufferDim raw) 1
--     (HalideDimension min2 extent2 stride2 _) <- peekElemOff (halideBufferDim raw) 2
--     (HalideDimension min3 extent3 stride3 _) <- peekElemOff (halideBufferDim raw) 3
--     let ptr0 = castPtr @_ @a (halideBufferHost raw)
--     when (ptr0 == nullPtr) . error $ "host is NULL"
--     forM [0 .. extent0 - 1] $ \i0 -> do
--       let ptr1 = ptr0 `advancePtr` fromIntegral (min0 + stride0 * i0)
--       forM [0 .. extent1 - 1] $ \i1 -> do
--         let ptr2 = ptr1 `advancePtr` fromIntegral (min1 + stride1 * i1)
--         forM [0 .. extent2 - 1] $ \i2 -> do
--           let ptr3 = ptr2 `advancePtr` fromIntegral (min2 + stride2 * i2)
--           forM [0 .. extent3 - 1] $ \i3 ->
--             peekElemOff ptr3 (fromIntegral (min3 + stride3 * i3))