{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE QuasiQuotes #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE UndecidableInstances #-}
module Language.Halide.Buffer
  ( 
  
    
    
    
    HalideBuffer (..)
    
    
  , allocaCpuBuffer
  , allocaBuffer
    
  , IsListPeek (..)
  , peekScalar
    
    
    
    
  , IsHalideBuffer (..)
  , withHalideBuffer
    
  , bufferFromPtrShapeStrides
  , bufferFromPtrShape
    
  , RawHalideBuffer (..)
  , HalideDimension (..)
  , HalideDeviceInterface
  , rowMajorStrides
  , colMajorStrides
  , isDeviceDirty
  , isHostDirty
  , getBufferExtent
  , bufferCopyToHost
  , withCopiedToHost
  , withCropped
  )
where
import Control.Exception (bracket_)
import Control.Monad (forM, unless, when)
import Control.Monad.ST (RealWorld)
import Data.Int
import Data.Kind (Type)
import Data.List qualified as List
import Data.Proxy
import Data.Vector.Storable qualified as S
import Data.Vector.Storable.Mutable qualified as SM
import Data.Word
import Foreign.Marshal.Alloc (alloca, free, mallocBytes)
import Foreign.Marshal.Array
import Foreign.Marshal.Utils
import Foreign.Ptr
import Foreign.Storable
import GHC.Stack (HasCallStack)
import GHC.TypeNats
import Language.C.Inline qualified as C
import Language.C.Inline.Cpp.Exception qualified as C
import Language.C.Inline.Unsafe qualified as CU
import Language.Halide.Context
import Language.Halide.Target
import Language.Halide.Type
import Prelude hiding (min)
-- | Haskell mirror of Halide's @halide_dimension_t@: the description of one
-- axis of a buffer.
--
-- The 'Storable' instance stores the four fields as consecutive 32-bit
-- words (16 bytes in total).
data HalideDimension = HalideDimension
  { halideDimensionMin :: {-# UNPACK #-} !Int32
  -- ^ coordinate of the first element along this axis (Halide's @min@)
  , halideDimensionExtent :: {-# UNPACK #-} !Int32
  -- ^ number of elements along this axis
  , halideDimensionStride :: {-# UNPACK #-} !Int32
  -- ^ distance, in elements, between neighbouring entries along this axis
  , halideDimensionFlags :: {-# UNPACK #-} !Word32
  -- ^ flag bits of the C struct ('simpleDimension' sets them to @0@)
  }
  deriving stock (Eq, Read, Show)
-- | Byte layout matching Halide's @halide_dimension_t@: four 32-bit fields
-- at offsets 0 (min), 4 (extent), 8 (stride) and 12 (flags); 16 bytes in
-- total with 4-byte alignment.
instance Storable HalideDimension where
  sizeOf _ = 16
  {-# INLINE sizeOf #-}
  alignment _ = 4
  {-# INLINE alignment #-}
  peek ptr = do
    dimMin <- peekByteOff ptr 0
    dimExtent <- peekByteOff ptr 4
    dimStride <- peekByteOff ptr 8
    dimFlags <- peekByteOff ptr 12
    pure (HalideDimension dimMin dimExtent dimStride dimFlags)
  {-# INLINE peek #-}
  poke ptr (HalideDimension dimMin dimExtent dimStride dimFlags) = do
    pokeByteOff ptr 0 dimMin
    pokeByteOff ptr 4 dimExtent
    pokeByteOff ptr 8 dimStride
    pokeByteOff ptr 12 dimFlags
  {-# INLINE poke #-}
-- | Build a 'HalideDimension' from an extent and a stride, with the origin
-- at @0@ and no flags set.
--
-- Both 'Int' arguments are narrowed to 'Int32' with 'fromIntegral', so
-- values outside the 'Int32' range wrap around silently.
simpleDimension :: Int -> Int -> HalideDimension
simpleDimension extent stride =
  HalideDimension
    { halideDimensionMin = 0
    , halideDimensionExtent = fromIntegral extent
    , halideDimensionStride = fromIntegral stride
    , halideDimensionFlags = 0
    }
{-# INLINE simpleDimension #-}
-- | Strides for a densely packed, row-major (C-order) array of the given
-- shape. The last axis is contiguous. Each earlier axis's stride is the
-- product of all the extents that follow it.
--
-- >>> rowMajorStrides [2, 3, 4]
-- [12,4,1]
rowMajorStrides
  :: Integral a
  => [a]
  -- ^ extent of each axis
  -> [a]
  -- ^ stride of each axis, in elements
rowMajorStrides = snd . foldr step (1, [])
  where
    -- Walk from the last axis to the first, carrying the product of the
    -- extents seen so far; that running product is the current axis' stride.
    step extent (suffixProduct, strides) = (extent * suffixProduct, suffixProduct : strides)
-- | Strides for a densely packed, column-major (Fortran-order) array of the
-- given shape. The first axis is contiguous. Each later axis's stride is the
-- product of all the extents before it.
--
-- Total: an empty shape yields an empty list of strides.
--
-- >>> colMajorStrides [2, 3, 4]
-- [1,2,6]
colMajorStrides
  :: Integral a
  => [a]
  -- ^ extent of each axis
  -> [a]
  -- ^ stride of each axis, in elements
colMajorStrides extents =
  -- 'scanl' yields one more prefix product than there are axes. Zipping it
  -- against the extents keeps exactly one stride per axis, without the
  -- partial 'init' (which fails on an empty shape).
  zipWith const (scanl (*) 1 extents) extents
-- | Opaque stand-in for Halide's @halide_device_interface_t@. It has no
-- constructors and is only ever handled through a 'Ptr'.
data HalideDeviceInterface
-- | Haskell mirror of Halide's @halide_buffer_t@.
--
-- Fields appear in the same order as in the C struct; the 'Storable'
-- instance defines the exact byte offsets.
data RawHalideBuffer = RawHalideBuffer
  { halideBufferDevice :: !Word64
  -- ^ device-side handle; non-zero while a device allocation is attached
  , halideBufferDeviceInterface :: !(Ptr HalideDeviceInterface)
  -- ^ interface managing the device allocation, or 'nullPtr'
  , halideBufferHost :: !(Ptr Word8)
  -- ^ pointer to the host-side data
  , halideBufferFlags :: !Word64
  -- ^ raw @flags@ word of the C struct
  , halideBufferType :: !HalideType
  -- ^ element type
  , halideBufferDimensions :: !Int32
  -- ^ number of entries in 'halideBufferDim'
  , halideBufferDim :: !(Ptr HalideDimension)
  -- ^ array of per-axis descriptors
  , halideBufferPadding :: !(Ptr ())
  -- ^ trailing @padding@ pointer of the C struct
  }
  deriving stock (Eq, Show)
-- | A @halide_buffer_t@ tagged at the type level with its number of
-- dimensions @n@ and its element type @a@.
--
-- Both tags are phantom: at runtime the value is exactly a
-- 'RawHalideBuffer'.
newtype HalideBuffer (n :: Nat) (a :: Type) = HalideBuffer
  { unHalideBuffer :: RawHalideBuffer
  }
  deriving stock (Eq, Show)
-- Template Haskell declaration splice, presumably provided by
-- "Language.Halide.Context" (verify). It appears to set up the inline-C/C++
-- context that lets the quasiquotes below use Halide types such as
-- @halide_buffer_t@.
importHalide
-- | Byte layout of Halide's @halide_buffer_t@:
--
-- > offset  0  device            (Word64)
-- > offset  8  device_interface  (pointer)
-- > offset 16  host              (pointer)
-- > offset 24  flags             (Word64)
-- > offset 32  type              (HalideType, 4 bytes)
-- > offset 36  dimensions        (Int32)
-- > offset 40  dim               (pointer)
-- > offset 48  padding           (pointer)
--
-- 56 bytes in total with 8-byte alignment. The hard-coded offsets assume
-- 8-byte pointers (a 64-bit target).
instance Storable RawHalideBuffer where
  sizeOf _ = 56
  alignment _ = 8
  peek ptr = do
    device <- peekByteOff ptr 0
    deviceInterface <- peekByteOff ptr 8
    host <- peekByteOff ptr 16
    flags <- peekByteOff ptr 24
    elemType <- peekByteOff ptr 32
    rank <- peekByteOff ptr 36
    dimArray <- peekByteOff ptr 40
    padding <- peekByteOff ptr 48
    pure (RawHalideBuffer device deviceInterface host flags elemType rank dimArray padding)
  poke ptr (RawHalideBuffer device deviceInterface host flags elemType rank dimArray padding) = do
    pokeByteOff ptr 0 device
    pokeByteOff ptr 8 deviceInterface
    pokeByteOff ptr 16 host
    pokeByteOff ptr 24 flags
    pokeByteOff ptr 32 elemType
    pokeByteOff ptr 36 rank
    pokeByteOff ptr 40 dimArray
    pokeByteOff ptr 48 padding
-- | Wrap an existing host pointer in a temporary @halide_buffer_t@ and run
-- an action on it.
--
-- The buffer and its dimension array exist only while the action runs. The
-- pointed-to data is borrowed: it is neither copied nor freed here.
--
-- Throws (via 'error') when:
--
--   * @shape@ and @stride@ have different lengths;
--   * the number of dimensions differs from the type-level @n@;
--   * after the action, the buffer still references data on the device.
bufferFromPtrShapeStrides
  :: forall n a b
   . (HasCallStack, KnownNat n, IsHalideType a)
  => Ptr a
  -- ^ CPU pointer to the data
  -> [Int]
  -- ^ extent of each dimension
  -> [Int]
  -- ^ stride of each dimension, in elements
  -> (Ptr (HalideBuffer n a) -> IO b)
  -- ^ action to run on the temporary buffer
  -> IO b
bufferFromPtrShapeStrides p shape stride action = do
  -- 'zipWith' below silently drops surplus entries, so reject mismatched
  -- lengths up front instead of building a truncated buffer.
  when (length shape /= length stride) $
    error $
      "shape and strides have different lengths: "
        <> show (length shape)
        <> " vs. "
        <> show (length stride)
  withArrayLen (zipWith simpleDimension shape stride) $ \n dim -> do
    unless (n == fromIntegral (natVal (Proxy @n))) $
      error $
        "specified wrong number of dimensions: "
          <> show n
          <> "; expected "
          <> show (natVal (Proxy @n))
          <> " from the type declaration"
    let !buffer =
          RawHalideBuffer
            { halideBufferDevice = 0
            , halideBufferDeviceInterface = nullPtr
            , halideBufferHost = castPtr p
            , halideBufferFlags = 0
            , halideBufferType = halideTypeFor (Proxy :: Proxy a)
            , halideBufferDimensions = fromIntegral n
            , halideBufferDim = dim
            , halideBufferPadding = nullPtr
            }
    with buffer $ \bufferPtr -> do
      r <- action (castPtr bufferPtr)
      -- The temporary buffer is about to disappear; if the action left a
      -- device allocation attached, its handle would be lost with it.
      hasDataOnDevice <-
        toEnum . fromIntegral
          <$> [CU.exp| bool { $(halide_buffer_t* bufferPtr)->device } |]
      when hasDataOnDevice $
        error "the Buffer still references data on the device; did you forget to call copyToHost?"
      pure r
-- | Like 'bufferFromPtrShapeStrides', but for densely packed column-major
-- data: the first dimension is contiguous. The strides are derived from the
-- shape with 'colMajorStrides'.
bufferFromPtrShape
  :: (HasCallStack, KnownNat n, IsHalideType a)
  => Ptr a
  -- ^ CPU pointer to the data
  -> [Int]
  -- ^ extent of each dimension
  -> (Ptr (HalideBuffer n a) -> IO b)
  -- ^ action to run on the temporary buffer
  -> IO b
bufferFromPtrShape ptr extents =
  let strides = colMajorStrides extents
   in bufferFromPtrShapeStrides ptr extents strides
-- | Types @t@ that can be temporarily presented to Halide as an
-- @n@-dimensional buffer with elements of type @a@.
--
-- Instances defined in this module include storable vectors (immutable and
-- mutable, viewed in place) and nested lists of depth 1 to 4 (copied into a
-- fresh vector first).
class (KnownNat n, IsHalideType a) => IsHalideBuffer t n a where
  -- | Run an action on a temporary buffer backed by the contents of @t@.
  -- Callers normally use 'withHalideBuffer', which orders the type arguments
  -- as @n a t b@ for convenient type applications.
  withHalideBufferImpl :: t -> (Ptr (HalideBuffer n a) -> IO b) -> IO b
-- | Run an action on a temporary Halide buffer built from a value.
--
-- The explicit @forall n a t b@ puts the dimension count and the element
-- type first, so callers can write e.g. @withHalideBuffer \@2 \@Float xs k@.
withHalideBuffer
  :: forall n a t b
   . IsHalideBuffer t n a
  => t
  -> (Ptr (HalideBuffer n a) -> IO b)
  -> IO b
withHalideBuffer value action = withHalideBufferImpl @t @n @a value action
-- | A storable vector is viewed in place as a 1-dimensional buffer; no copy
-- is made.
--
-- NOTE(review): the buffer's host pointer aliases the vector's memory, so a
-- pipeline that writes to this buffer mutates a supposedly immutable vector.
instance IsHalideType a => IsHalideBuffer (S.Vector a) 1 a where
  withHalideBufferImpl vec action =
    S.unsafeWith vec $ \dataPtr ->
      bufferFromPtrShape dataPtr [S.length vec] action
-- | A mutable storable vector living in 'IO' is viewed in place as a
-- 1-dimensional buffer. Writes made through the buffer are visible in the
-- vector afterwards.
instance IsHalideType a => IsHalideBuffer (S.MVector RealWorld a) 1 a where
  withHalideBufferImpl mvec action =
    SM.unsafeWith mvec $ \dataPtr ->
      bufferFromPtrShape dataPtr [SM.length mvec] action
-- | A flat list is copied into a fresh storable vector, which is then viewed
-- as a 1-dimensional buffer. Writes made through the buffer are therefore
-- not reflected in the original list.
instance IsHalideType a => IsHalideBuffer [a] 1 a where
  withHalideBufferImpl xs action = withHalideBuffer (S.fromList xs) action
-- | A list of @d0@ rows, each holding @d1@ elements, is copied into a fresh
-- storable vector and viewed as a @d0 x d1@ buffer.
--
-- Element @xs !! i !! j@ lands at flat index @i + j * d0@. This is
-- column-major order, matching the strides chosen by 'bufferFromPtrShape'.
--
-- Throws (via 'error') if the rows do not all have the same length.
instance IsHalideType a => IsHalideBuffer [[a]] 2 a where
  withHalideBufferImpl xs f = do
    let d0 = length xs
        -- The first row fixes the expected row length; an empty outer list
        -- gives a 0 x 0 buffer.
        d1 = case xs of
          (row : _) -> length row
          [] -> 0
        -- Transposing turns the nested row order into column-major order.
        v = S.fromList (List.concat (List.transpose xs))
    -- Check every row instead of only the total element count: ragged rows
    -- of lengths 2, 1 and 3 still add up to 3 * 2 elements, and
    -- 'List.transpose' would silently scramble them.
    unless (all ((== d1) . length) xs) $
      error "list doesn't have a regular shape (i.e. rows have varying number of elements)"
    S.unsafeWith v $ \cpuPtr -> bufferFromPtrShape cpuPtr [d0, d1] f
-- | A list of @d0@ planes, each holding @d1@ rows of @d2@ elements, is copied
-- into a fresh storable vector and viewed as a @d0 x d1 x d2@ buffer.
--
-- Element @xs !! i !! j !! k@ lands at flat index
-- @i + j * d0 + k * d0 * d1@. This is column-major order, matching the
-- strides chosen by 'bufferFromPtrShape'.
--
-- Throws (via 'error') if the nesting is ragged at any level.
instance IsHalideType a => IsHalideBuffer [[[a]]] 3 a where
  withHalideBufferImpl xs f = do
    let d0 = length xs
        -- Extents are read off the first plane and its first row; empty
        -- lists give zero-sized trailing extents.
        d1 = case xs of
          (plane : _) -> length plane
          [] -> 0
        d2 = case xs of
          ((row : _) : _) -> length row
          _ -> 0
        rowOk row = length row == d2
        planeOk plane = length plane == d1 && all rowOk plane
        v =
          S.fromList
            . List.concat
            . List.concatMap List.transpose
            . List.transpose
            . fmap List.transpose
            $ xs
    -- Check every sublist rather than just the total element count: a ragged
    -- input can still add up to d0 * d1 * d2 elements, and 'List.transpose'
    -- would then silently scramble it.
    unless (all planeOk xs) $
      error "list doesn't have a regular shape (i.e. rows have varying number of elements)"
    S.unsafeWith v $ \cpuPtr -> bufferFromPtrShape cpuPtr [d0, d1, d2] f
-- | A list of @d0@ volumes, each holding @d1@ planes of @d2@ rows of @d3@
-- elements, is copied into a fresh storable vector and viewed as a
-- @d0 x d1 x d2 x d3@ buffer.
--
-- Element @xs !! i !! j !! k !! l@ lands at flat index
-- @i + j * d0 + k * d0 * d1 + l * d0 * d1 * d2@. This is column-major order,
-- matching the strides chosen by 'bufferFromPtrShape'.
--
-- Throws (via 'error') if the nesting is ragged at any level.
instance IsHalideType a => IsHalideBuffer [[[[a]]]] 4 a where
  withHalideBufferImpl xs f = do
    let d0 = length xs
        -- Extents are read off the first volume, plane and row; empty lists
        -- give zero-sized trailing extents.
        d1 = case xs of
          (vol : _) -> length vol
          [] -> 0
        d2 = case xs of
          ((plane : _) : _) -> length plane
          _ -> 0
        d3 = case xs of
          (((row : _) : _) : _) -> length row
          _ -> 0
        rowOk row = length row == d3
        planeOk plane = length plane == d2 && all rowOk plane
        volumeOk vol = length vol == d1 && all planeOk vol
        v =
          S.fromList
            . List.concat
            . List.concat
            . List.concatMap (fmap List.transpose . List.transpose . fmap List.transpose)
            . List.transpose
            . fmap (List.transpose . fmap List.transpose)
            $ xs
    -- Check every sublist rather than just the total element count: a ragged
    -- input can still add up to d0 * d1 * d2 * d3 elements, and the
    -- transposes above would then silently scramble it.
    unless (all volumeOk xs) $
      error "list doesn't have a regular shape (i.e. rows have varying number of elements)"
    S.unsafeWith v $ \cpuPtr -> bufferFromPtrShape cpuPtr [d0, d1, d2, d3] f
-- | Monadic counterpart of 'when': run the condition, then run the action
-- only if the condition produced 'True'.
whenM :: Monad m => m Bool -> m () -> m ()
whenM mcond action = do
  shouldRun <- mcond
  if shouldRun then action else pure ()
-- | Like 'allocaBuffer', but always targets the host ('hostTarget'): run an
-- action on a temporary CPU buffer of the given shape.
allocaCpuBuffer
  :: forall n a b
   . (HasCallStack, KnownNat n, IsHalideType a)
  => [Int]
  -- ^ extent of each dimension
  -> (Ptr (HalideBuffer n a) -> IO b)
  -- ^ action to run on the temporary buffer
  -> IO b
allocaCpuBuffer shape action = allocaBuffer hostTarget shape action
-- | Number of bytes needed to hold every element of the buffer: the product
-- of all extents times the size of one element.
--
-- Strides are ignored. The result is therefore only the correct allocation
-- size for densely packed buffers, such as those laid out by
-- 'colMajorStrides'.
--
-- Each element occupies whole bytes. Sub-byte types (e.g. 1-bit booleans)
-- are rounded up to one byte per lane, so they are no longer computed as
-- 0 bytes.
getTotalBytes :: Ptr RawHalideBuffer -> IO Int
getTotalBytes buf = do
  fromIntegral
    <$> [CU.block| size_t {
          auto const& b = *$(const halide_buffer_t* buf);
          auto const n = std::accumulate(b.dim, b.dim + b.dimensions, size_t{1},
                                         [](auto acc, auto const& dim) { return acc * dim.extent; });
          // Round the bit width up to whole bytes before scaling by lanes;
          // bits * lanes / 8 would truncate sub-byte types to zero bytes.
          auto const bytesPerElement = static_cast<size_t>((b.type.bits + 7) / 8) * b.type.lanes;
          return n * bytesPerElement;
        } |]
-- | Allocate host memory for all of the buffer's elements and store the new
-- pointer in the buffer's @host@ field.
--
-- The size comes from 'getTotalBytes' and the memory from 'mallocBytes';
-- release it with 'freeHostMemory'. Any previous @host@ pointer is
-- overwritten without being freed.
allocateHostMemory :: Ptr RawHalideBuffer -> IO ()
allocateHostMemory buf = do
  numBytes <- getTotalBytes buf
  ptr <- mallocBytes numBytes
  [CU.block| void { $(halide_buffer_t* buf)->host = $(uint8_t* ptr); } |]
-- | Detach the buffer's @host@ pointer (resetting the field to @nullptr@) and
-- release the memory with 'free'.
--
-- Only valid for host memory that came from 'mallocBytes', e.g. via
-- 'allocateHostMemory'.
freeHostMemory :: Ptr RawHalideBuffer -> IO ()
freeHostMemory buf =
  free
    =<< [CU.block| uint8_t* {
          auto& b = *$(halide_buffer_t* buf);
          auto const p = b.host;
          b.host = nullptr;
          return p;
        } |]
-- | Allocate device memory for the buffer by calling the given interface's
-- @device_malloc@ hook.
--
-- Throws (via 'error') in two cases:
--
--   * the interface pointer is NULL, which the hook call would otherwise
--     dereference;
--   * @device_malloc@ reports failure with a non-zero return code, which was
--     previously ignored and left the buffer without a device allocation.
allocateDeviceMemory :: Ptr HalideDeviceInterface -> Ptr RawHalideBuffer -> IO ()
allocateDeviceMemory interface buf = do
  when (interface == nullPtr) $
    error "cannot allocate device memory: device interface is NULL"
  code <-
    [CU.block| int {
      auto const* interface = $(const halide_device_interface_t* interface);
      return interface->device_malloc(nullptr, $(halide_buffer_t* buf), interface);
    } |]
  unless (code == 0) $
    error $ "device_malloc failed with error code " <> show code
-- | Release the device allocation of @buf@ through its own device interface
-- and reset its @device@ handle to 0.
--
-- Fails with 'error' when @buf@ has no @device_interface@. The return code of
-- @device_free@ is not inspected.
freeDeviceMemory :: HasCallStack => Ptr RawHalideBuffer -> IO ()
freeDeviceMemory buf = do
  iface <-
    [CU.exp| const halide_device_interface_t* { $(const halide_buffer_t* buf)->device_interface } |]
  if iface == nullPtr
    then error "cannot free device memory: device_interface is NULL"
    else
      [CU.block| void {
        auto& b = *$(halide_buffer_t* buf);
        b.device_interface->device_free(nullptr, &b);
        b.device = 0;
      } |]
-- | Temporarily allocate a buffer of the given @shape@ for @target@ and pass
-- it to @action@.
--
-- The buffer uses column-major strides ('colMajorStrides'). Memory placement
-- depends on the target:
--
-- * If @target@ has no device interface, host memory is allocated
--   ('allocateHostMemory').
-- * Otherwise only device memory is allocated through the target's device
--   interface.
--
-- The memory is released when @action@ returns or throws ('bracket_').
--
-- Fails with 'error' in these cases:
--
-- * @length shape@ differs from the type-level rank @n@.
-- * After @action@, a host buffer has acquired a device allocation.
-- * After @action@, a device buffer has acquired a host allocation.
allocaBuffer
  :: forall n a b
   . (HasCallStack, KnownNat n, IsHalideType a)
  => Target
  -> [Int]
  -> (Ptr (HalideBuffer n a) -> IO b)
  -> IO b
allocaBuffer target shape action = do
  deviceInterface <- getDeviceInterface target
  let onHost = deviceInterface == nullPtr
      expectedRank = natVal (Proxy @n)
      dims = zipWith simpleDimension shape (colMajorStrides shape)
  withArrayLen dims $ \rank dimPtr -> do
    unless (rank == fromIntegral expectedRank) . error $
      "specified wrong number of dimensions: "
        <> show rank
        <> "; expected "
        <> show expectedRank
        <> " from the type declaration"
    let descriptor =
          RawHalideBuffer
            { halideBufferDevice = 0
            , halideBufferDeviceInterface = nullPtr
            , halideBufferHost = nullPtr
            , halideBufferFlags = 0
            , halideBufferType = halideTypeFor (Proxy @a)
            , halideBufferDimensions = fromIntegral rank
            , halideBufferDim = dimPtr
            , halideBufferPadding = nullPtr
            }
        acquire =
          if onHost then allocateHostMemory else allocateDeviceMemory deviceInterface
        release =
          if onHost then freeHostMemory else freeDeviceMemory
    with descriptor $ \buf ->
      bracket_ (acquire buf) (release buf) $ do
        result <- action (castPtr buf)
        hostIsNull <- toBool <$> [CU.exp| bool { $(halide_buffer_t* buf)->host == nullptr } |]
        deviceIsNull <- toBool <$> [CU.exp| bool { $(halide_buffer_t* buf)->device == 0 } |]
        -- A host-only buffer must not have picked up a device allocation...
        unless (not onHost || deviceIsNull) . error $
          "buffer was allocated on host, but its device pointer is not NULL"
            <> "; did you forget a copyToHost in your pipeline?"
        -- ...and a device-only buffer must not have picked up host memory.
        unless (onHost || hostIsNull) . error $
          "buffer was allocated on device, but its host pointer is not NULL"
            <> "; did you add an extra copyToHost?"
        pure result
-- | Look up the Halide device interface for the device API selected by
-- @target@ ('deviceAPIForTarget').
--
-- Returns 'nullPtr' for 'DeviceNone' and 'DeviceHost'. For any other API, it
-- calls @Halide::get_device_interface_for_device_api@. C++ exceptions are
-- routed through @handle_halide_exceptions@ and 'C.throwBlock'.
getDeviceInterface :: Target -> IO (Ptr HalideDeviceInterface)
getDeviceInterface target =
  case deviceAPIForTarget target of
    DeviceNone -> pure nullPtr
    DeviceHost -> pure nullPtr
    device -> interfaceFor (fromIntegral (fromEnum device))
  where
    interfaceFor api =
      withCxxTarget target $ \target' ->
        [C.throwBlock| const halide_device_interface_t* {
          return handle_halide_exceptions([=](){
            auto const device = static_cast<Halide::DeviceAPI>($(int api));
            auto const& target = *$(const Halide::Target* target');
            return Halide::get_device_interface_for_device_api(device, target, "getDeviceInterface");
          });
        } |]
-- | Query the @device_dirty@ flag of a Halide buffer.
isDeviceDirty :: Ptr RawHalideBuffer -> IO Bool
isDeviceDirty p = do
  dirty <- [CU.exp| bool { $(const halide_buffer_t* p)->device_dirty() } |]
  pure (toBool dirty)
-- | Set or clear the @device_dirty@ flag of a Halide buffer.
setDeviceDirty :: Bool -> Ptr RawHalideBuffer -> IO ()
setDeviceDirty dirty p = do
  let flag = fromBool dirty
  [CU.exp| void { $(halide_buffer_t* p)->set_device_dirty($(bool flag)) } |]
-- | Query the @host_dirty@ flag of a Halide buffer.
isHostDirty :: Ptr RawHalideBuffer -> IO Bool
isHostDirty p = do
  dirty <- [CU.exp| bool { $(const halide_buffer_t* p)->host_dirty() } |]
  pure (toBool dirty)
-- | Set or clear the @host_dirty@ flag of a Halide buffer.
setHostDirty :: Bool -> Ptr RawHalideBuffer -> IO ()
setHostDirty dirty p = do
  let flag = fromBool dirty
  [CU.exp| void { $(halide_buffer_t* p)->set_host_dirty($(bool flag)) } |]
-- | Copy device data back to host memory if the device side is dirty.
--
-- Does nothing when @device_dirty@ is not set. Otherwise it calls
-- @device_interface->copy_to_host@. It fails with 'error' in these cases:
--
-- * The buffer has no device interface.
-- * The buffer has no host memory.
-- * @device_dirty@ is still set after the copy.
bufferCopyToHost :: HasCallStack => Ptr RawHalideBuffer -> IO ()
bufferCopyToHost p = do
  dirty <- isDeviceDirty p
  when dirty $ do
    raw <- peek p
    when (raw.halideBufferDeviceInterface == nullPtr) $
      error "device_dirty is set, but device_interface is NULL"
    when (raw.halideBufferHost == nullPtr) $
      error "host is NULL, did you forget to allocate memory?"
    [CU.block| void {
      auto& buf = *$(halide_buffer_t* p);
      buf.device_interface->copy_to_host(nullptr, &buf);
    } |]
    -- A successful copy_to_host clears device_dirty; anything else is a failure.
    stillDirty <- isDeviceDirty p
    when stillDirty $
      error "device_dirty is set right after a copy_to_host; something went wrong..."
-- | Fail with 'error' unless the runtime @dimensions@ field of @raw@ equals
-- the type-level rank @n@.
checkNumberOfDimensions :: forall n. (HasCallStack, KnownNat n) => RawHalideBuffer -> IO ()
checkNumberOfDimensions raw
  | fromIntegral typeLevelRank == runtimeRank = pure ()
  | otherwise =
      error $
        "type-level and runtime number of dimensions do not match: "
          <> show typeLevelRank
          <> " != "
          <> show runtimeRank
  where
    typeLevelRank = natVal (Proxy @n)
    runtimeRank = raw.halideBufferDimensions
-- | Run @action@ on a view of @src@ that is cropped along dimension @d@ so its
-- coordinates along @d@ start at @min@ and span @extent@.
--
-- The view shares memory with @src@. Its @host@ pointer, if any, is shifted
-- to point at the new min coordinate. When @src@ has a device allocation, a
-- device crop is created with @device_crop@ and released with
-- @device_release_crop@ once @action@ finishes, including when it throws.
-- The view and its dimension array live on the stack and are only valid
-- inside @action@.
--
-- Fails with 'error' in these cases:
--
-- * @d@ is not a valid dimension index of @src@.
-- * The device crop fails.
--
-- The new range is not checked against the extent of @src@.
withCropped
  :: Ptr (HalideBuffer n a)
  -- ^ source buffer
  -> Int
  -- ^ dimension to crop
  -> Int
  -- ^ new min
  -> Int
  -- ^ new extent
  -> (Ptr (HalideBuffer n a) -> IO b)
  -- ^ action to run on the cropped view
  -> IO b
withCropped
  (castPtr -> src)
  dimIndex
  (fromIntegral -> min)
  (fromIntegral -> extent)
  action = do
    rank <- fromIntegral <$> [CU.exp| int { $(const halide_buffer_t* src)->dimensions } |]
    -- Guard against reading/writing past the end of the dim arrays in C.
    unless (0 <= dimIndex && dimIndex < rank) . error $
      "withCropped: dimension index "
        <> show dimIndex
        <> " is out of bounds for a buffer with "
        <> show rank
        <> " dimensions"
    let d = fromIntegral dimIndex
    alloca $ \dst ->
      allocaArray rank $ \dstDim -> do
        -- Halide's device_crop rejects a destination that already has a device
        -- allocation. The handle copied from src is therefore dropped before
        -- cropping. The return code (0 on success) is propagated to Haskell.
        code <-
          [CU.block| int {
            auto const& src = *$(const halide_buffer_t* src);
            auto& dst = *$(halide_buffer_t* dst);
            auto const d = $(int d);
            dst = src;
            dst.dim = $(halide_dimension_t* dstDim);
            memcpy(dst.dim, src.dim, src.dimensions * sizeof(halide_dimension_t));
            if (dst.host != nullptr) {
              auto const shift = static_cast<std::ptrdiff_t>($(int min)) - src.dim[d].min;
              dst.host += shift * src.dim[d].stride * ((src.type.bits + 7) / 8);
            }
            dst.dim[d].min = $(int min);
            dst.dim[d].extent = $(int extent);
            if (src.device != 0 && src.device_interface != nullptr) {
              dst.device = 0;
              dst.device_interface = nullptr;
              return src.device_interface->device_crop(nullptr, &src, &dst);
            }
            return 0;
          } |]
        unless (code == 0) . error $
          "withCropped: device_crop failed with error code " <> show code
        -- Device crops may hold resources (e.g. sub-buffers); release them.
        let releaseDeviceCrop =
              [CU.block| void {
                auto& dst = *$(halide_buffer_t* dst);
                if (dst.device != 0 && dst.device_interface != nullptr) {
                  dst.device_interface->device_release_crop(nullptr, &dst);
                }
              } |]
        bracket_ (pure ()) releaseDeviceCrop (action (castPtr dst))
-- | Extent of dimension @dim@ of the buffer.
--
-- Fails with "index out of bounds" unless @0 <= dim < n@, where @n@ is the
-- type-level rank. Negative indices used to slip through the check and read
-- before the start of the @dim@ array.
getBufferExtent :: forall n a. KnownNat n => Ptr (HalideBuffer n a) -> Int -> IO Int
getBufferExtent (castPtr -> buf) dim
  | 0 <= dim && dim < rank = do
      let d = fromIntegral dim
      fromIntegral <$> [CU.exp| int { $(const halide_buffer_t* buf)->dim[$(int d)].extent } |]
  | otherwise = error "index out of bounds"
  where
    rank = fromIntegral (natVal (Proxy @n))
-- | Read the single element of a rank-0 buffer.
--
-- Device data is first made available on the host ('withCopiedToHost'). Fails
-- with 'error' if the runtime rank is not 0 or the host pointer is NULL.
peekScalar :: forall a. (HasCallStack, IsHalideType a) => Ptr (HalideBuffer 0 a) -> IO a
peekScalar p = withCopiedToHost p $ do
  raw <- peek (castPtr @_ @RawHalideBuffer p)
  checkNumberOfDimensions @0 raw
  let host = raw.halideBufferHost
  if host == nullPtr
    then error "host is NULL"
    else peek (castPtr @_ @a host)
-- | The nested list type of depth @n@ with element type @a@: @NestedList 0 a@
-- is @a@, @NestedList 1 a@ is @[a]@, @NestedList 2 a@ is @[[a]]@, and so on.
-- Equations exist only for depths 0 through 5; for any other @n@ the family
-- does not reduce.
type family NestedList (n :: Nat) (a :: Type) where
  NestedList 0 a = a
  NestedList 1 a = [a]
  NestedList 2 a = [[a]]
  NestedList 3 a = [[[a]]]
  NestedList 4 a = [[[[a]]]]
  NestedList 5 a = [[[[[a]]]]]
-- | Number of list layers wrapped around a type, e.g. @NestedListLevel [[Int]]@
-- is 2. The equations of this closed family are tried in order, so the
-- catch-all (level 0) applies only once the argument is known not to be a list.
type family NestedListLevel (a :: Type) :: Nat where
  NestedListLevel [a] = 1 + NestedListLevel a
  NestedListLevel a = 0
-- | Element type left after stripping every list layer, e.g.
-- @NestedListType [[Int]]@ is @Int@. As with 'NestedListLevel', the catch-all
-- equation applies only once the argument is known not to be a list.
type family NestedListType (a :: Type) :: Type where
  NestedListType [a] = NestedListType a
  NestedListType a = a
-- | Buffers of rank @n@ with element type @a@ that can be read back into a
-- nested Haskell list of type @b@ (@b ~ 'NestedList' n a@).
--
-- The superclass equalities tie @b@ to @n@ and @a@ in both directions
-- ('NestedListLevel' and 'NestedListType'). The functional dependencies let
-- any two of @n@, @a@, @b@ determine the third.
class
  ( KnownNat n
  , IsHalideType a
  , NestedList n a ~ b
  , NestedListLevel b ~ n
  , NestedListType b ~ a
  ) =>
  IsListPeek n a b
    | n a -> b
    , n b -> a
    , a b -> n
  where
  -- | Read the contents of a rank-@n@ buffer into a nested list.
  peekToList :: HasCallStack => Ptr (HalideBuffer n a) -> IO b
-- | Read a rank-1 buffer into a list; element @i@ of the list is the element
-- at coordinate @min + i@.
--
-- Halide's @host@ pointer already points at the element at the min
-- coordinate, so the offset of element @i@ is @i * stride@. The previous code
-- added @min@ as well, which read the wrong elements from any buffer with a
-- nonzero min (e.g. a 'withCropped' view).
instance
  (IsHalideType a, NestedListLevel [a] ~ 1, NestedListType [a] ~ a)
  => IsListPeek 1 a [a]
  where
  peekToList :: HasCallStack => Ptr (HalideBuffer 1 a) -> IO [a]
  peekToList p = withCopiedToHost p $ do
    raw <- peek (castPtr @_ @RawHalideBuffer p)
    -- Make sure dim[0] actually exists before reading it.
    checkNumberOfDimensions @1 raw
    HalideDimension _ extent0 stride0 _ <- peekElemOff (halideBufferDim raw) 0
    let ptr0 :: Ptr a
        ptr0 = castPtr @_ @a (halideBufferHost raw)
    when (ptr0 == nullPtr) . error $ "host is NULL"
    -- Offsets are computed in Int to avoid Int32 overflow on large buffers.
    forM [0 .. fromIntegral extent0 - 1] $ \i0 ->
      peekElemOff ptr0 (i0 * fromIntegral stride0)
-- | Read a rank-2 buffer into a nested list indexed as @[i0][i1]@. The outer
-- list runs over dimension 0 and the inner lists over dimension 1, relative
-- to each dimension's min.
--
-- Halide's @host@ pointer already points at the element at the min
-- coordinates, so offsets are @i * stride@ per dimension. The previous code
-- also added each dimension's @min@ and read the wrong elements from buffers
-- with a nonzero min.
instance
  (IsHalideType a, NestedListLevel [[a]] ~ 2, NestedListType [[a]] ~ a)
  => IsListPeek 2 a [[a]]
  where
  peekToList :: HasCallStack => Ptr (HalideBuffer 2 a) -> IO [[a]]
  peekToList p = withCopiedToHost p $ do
    raw <- peek (castPtr @_ @RawHalideBuffer p)
    -- Make sure dim[0] and dim[1] actually exist before reading them.
    checkNumberOfDimensions @2 raw
    HalideDimension _ extent0 stride0 _ <- peekElemOff (halideBufferDim raw) 0
    HalideDimension _ extent1 stride1 _ <- peekElemOff (halideBufferDim raw) 1
    let ptr0 :: Ptr a
        ptr0 = castPtr @_ @a (halideBufferHost raw)
    when (ptr0 == nullPtr) . error $ "host is NULL"
    -- Offsets are computed in Int to avoid Int32 overflow on large buffers.
    forM [0 .. fromIntegral extent0 - 1] $ \i0 -> do
      let ptr1 = ptr0 `advancePtr` (i0 * fromIntegral stride0)
      forM [0 .. fromIntegral extent1 - 1] $ \i1 ->
        peekElemOff ptr1 (i1 * fromIntegral stride1)
-- | Read a rank-3 buffer into a nested list indexed as @[i0][i1][i2]@. The
-- outermost list runs over dimension 0, relative to each dimension's min.
--
-- Halide's @host@ pointer already points at the element at the min
-- coordinates, so offsets are @i * stride@ per dimension. The previous code
-- also added each dimension's @min@ and read the wrong elements from buffers
-- with a nonzero min.
instance
  (IsHalideType a, NestedListLevel [[[a]]] ~ 3, NestedListType [[[a]]] ~ a)
  => IsListPeek 3 a [[[a]]]
  where
  peekToList :: HasCallStack => Ptr (HalideBuffer 3 a) -> IO [[[a]]]
  peekToList p = withCopiedToHost p $ do
    raw <- peek (castPtr @_ @RawHalideBuffer p)
    -- Make sure dim[0..2] actually exist before reading them.
    checkNumberOfDimensions @3 raw
    HalideDimension _ extent0 stride0 _ <- peekElemOff (halideBufferDim raw) 0
    HalideDimension _ extent1 stride1 _ <- peekElemOff (halideBufferDim raw) 1
    HalideDimension _ extent2 stride2 _ <- peekElemOff (halideBufferDim raw) 2
    let ptr0 :: Ptr a
        ptr0 = castPtr @_ @a (halideBufferHost raw)
    when (ptr0 == nullPtr) . error $ "host is NULL"
    -- Offsets are computed in Int to avoid Int32 overflow on large buffers.
    forM [0 .. fromIntegral extent0 - 1] $ \i0 -> do
      let ptr1 = ptr0 `advancePtr` (i0 * fromIntegral stride0)
      forM [0 .. fromIntegral extent1 - 1] $ \i1 -> do
        let ptr2 = ptr1 `advancePtr` (i1 * fromIntegral stride1)
        forM [0 .. fromIntegral extent2 - 1] $ \i2 ->
          peekElemOff ptr2 (i2 * fromIntegral stride2)
-- | Read a rank-4 buffer into a nested list indexed as @[i0][i1][i2][i3]@.
-- The outermost list runs over dimension 0, relative to each dimension's min.
--
-- Halide's @host@ pointer already points at the element at the min
-- coordinates, so offsets are @i * stride@ per dimension. The previous code
-- also added each dimension's @min@ and read the wrong elements from buffers
-- with a nonzero min.
instance
  (IsHalideType a, NestedListLevel [[[[a]]]] ~ 4, NestedListType [[[[a]]]] ~ a)
  => IsListPeek 4 a [[[[a]]]]
  where
  peekToList :: HasCallStack => Ptr (HalideBuffer 4 a) -> IO [[[[a]]]]
  peekToList p = withCopiedToHost p $ do
    raw <- peek (castPtr @_ @RawHalideBuffer p)
    -- Make sure dim[0..3] actually exist before reading them.
    checkNumberOfDimensions @4 raw
    HalideDimension _ extent0 stride0 _ <- peekElemOff (halideBufferDim raw) 0
    HalideDimension _ extent1 stride1 _ <- peekElemOff (halideBufferDim raw) 1
    HalideDimension _ extent2 stride2 _ <- peekElemOff (halideBufferDim raw) 2
    HalideDimension _ extent3 stride3 _ <- peekElemOff (halideBufferDim raw) 3
    let ptr0 :: Ptr a
        ptr0 = castPtr @_ @a (halideBufferHost raw)
    when (ptr0 == nullPtr) . error $ "host is NULL"
    -- Offsets are computed in Int to avoid Int32 overflow on large buffers.
    forM [0 .. fromIntegral extent0 - 1] $ \i0 -> do
      let ptr1 = ptr0 `advancePtr` (i0 * fromIntegral stride0)
      forM [0 .. fromIntegral extent1 - 1] $ \i1 -> do
        let ptr2 = ptr1 `advancePtr` (i1 * fromIntegral stride1)
        forM [0 .. fromIntegral extent2 - 1] $ \i2 -> do
          let ptr3 = ptr2 `advancePtr` (i2 * fromIntegral stride2)
          forM [0 .. fromIntegral extent3 - 1] $ \i3 ->
            peekElemOff ptr3 (i3 * fromIntegral stride3)
-- | Run @action@ while the buffer's contents are available in host memory.
--
-- The behaviour depends on where the buffer's memory lives:
--
-- * Device allocation but no host memory: a temporary host allocation is made
--   ('allocateHostMemory'). It is filled from the device and released again
--   after @action@, resetting @host@ to NULL.
-- * Host memory already exists: it is kept as is and only refreshed from the
--   device when @device_dirty@ is set ('bufferCopyToHost').
--
-- Previously the host pointer was unconditionally replaced and then freed.
-- That detached user-owned host memory from the buffer, and forcing
-- @device_dirty@ could overwrite newer host data with stale device data.
withCopiedToHost :: Ptr (HalideBuffer n a) -> IO b -> IO b
withCopiedToHost (castPtr @_ @RawHalideBuffer -> buf) action = do
  raw <- peek buf
  let hasDevice = raw.halideBufferDevice /= 0
      -- We only own (and may free) host memory that we allocated ourselves.
      needsTemporaryHost = hasDevice && raw.halideBufferHost == nullPtr
      allocate = when needsTemporaryHost $ allocateHostMemory buf
      deallocate = when needsTemporaryHost $ freeHostMemory buf
  bracket_ allocate deallocate $ do
    -- Freshly malloc'ed host memory holds garbage, so force a copy from the
    -- device; an existing host allocation is only refreshed if it is stale.
    when needsTemporaryHost $ setDeviceDirty True buf
    when hasDevice $ bufferCopyToHost buf
    action