{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE QuasiQuotes #-}
{-# LANGUAGE TemplateHaskell #-}
module Language.Halide.Buffer
(
HalideBuffer (..)
, allocaCpuBuffer
, IsListPeek (..)
, IsHalideBuffer (..)
, withHalideBuffer
, bufferFromPtrShapeStrides
, bufferFromPtrShape
, RawHalideBuffer (..)
, HalideDimension (..)
, HalideDeviceInterface
, rowMajorStrides
, colMajorStrides
, isDeviceDirty
, isHostDirty
, bufferCopyToHost
)
where
import Control.Monad (forM, unless, when)
import Control.Monad.ST (RealWorld)
import Data.Foldable (foldl')
import Data.Int
import Data.Kind (Type)
import qualified Data.List as List
import Data.Proxy
import qualified Data.Vector.Storable as S
import qualified Data.Vector.Storable.Mutable as SM
import Data.Word
import Foreign.Marshal.Array
import Foreign.Marshal.Utils
import Foreign.Ptr
import Foreign.Storable
import GHC.Stack (HasCallStack)
import GHC.TypeNats
import qualified Language.C.Inline as C
import qualified Language.C.Inline.Cpp.Exception as C
import qualified Language.C.Inline.Unsafe as CU
import Language.Halide.Context
import Language.Halide.Type
-- | Haskell counterpart of one entry of a Halide buffer's dimension array
-- (the @dim@ array of 'RawHalideBuffer'). Its 16-byte, 4-byte-aligned memory
-- layout is fixed by the 'Storable' instance.
data HalideDimension = HalideDimension
  { halideDimensionMin :: {-# UNPACK #-} !Int32
  -- ^ smallest coordinate along this dimension (byte offset 0)
  , halideDimensionExtent :: {-# UNPACK #-} !Int32
  -- ^ number of elements along this dimension (byte offset 4)
  , halideDimensionStride :: {-# UNPACK #-} !Int32
  -- ^ distance, in elements, between neighbours along this dimension (byte offset 8)
  , halideDimensionFlags :: {-# UNPACK #-} !Word32
  -- ^ flag bits, stored verbatim (byte offset 12)
  }
  deriving stock (Read, Show, Eq)
-- | Four consecutive 32-bit fields: min, extent, stride, flags.
instance Storable HalideDimension where
  sizeOf _ = 16
  {-# INLINE sizeOf #-}

  alignment _ = 4
  {-# INLINE alignment #-}

  peek ptr = do
    dimMin <- peekByteOff ptr 0
    dimExtent <- peekByteOff ptr 4
    dimStride <- peekByteOff ptr 8
    dimFlags <- peekByteOff ptr 12
    pure (HalideDimension dimMin dimExtent dimStride dimFlags)
  {-# INLINE peek #-}

  poke ptr (HalideDimension dimMin dimExtent dimStride dimFlags) = do
    pokeByteOff ptr 0 dimMin
    pokeByteOff ptr 4 dimExtent
    pokeByteOff ptr 8 dimStride
    pokeByteOff ptr 12 dimFlags
  {-# INLINE poke #-}
-- | A dimension starting at coordinate 0 with no flags set, built from an
-- extent and a stride (both in elements).
simpleDimension :: Int -> Int -> HalideDimension
simpleDimension extent stride =
  HalideDimension
    { halideDimensionMin = 0
    , halideDimensionExtent = fromIntegral extent
    , halideDimensionStride = fromIntegral stride
    , halideDimensionFlags = 0
    }
{-# INLINE simpleDimension #-}
-- | Strides (in elements) of a densely packed array whose /last/ dimension is
-- contiguous (row-major / C order).
--
-- >>> rowMajorStrides [2, 3, 4]
-- [12,4,1]
rowMajorStrides
  :: Integral a
  => [a]
  -- ^ extents, one per dimension
  -> [a]
rowMajorStrides extents =
  -- scanr yields the running suffix products followed by a trailing 1; the
  -- head of that list is the total element count, which is not a stride.
  case scanr (*) 1 extents of
    _total : strides -> strides
    [] -> []
-- | Strides (in elements) of a densely packed array whose /first/ dimension is
-- contiguous (column-major / Fortran order).
--
-- Total for every input; in particular @colMajorStrides [] == []@, consistent
-- with 'rowMajorStrides'.
--
-- >>> colMajorStrides [2, 3, 4]
-- [1,2,6]
colMajorStrides
  :: Integral a
  => [a]
  -- ^ extents, one per dimension
  -> [a]
colMajorStrides extents =
  -- scanl produces one prefix product per extent plus the final total; pairing
  -- it with the extents keeps exactly one stride per dimension. Unlike the
  -- previous @init@-based form, this does not crash on an empty shape.
  zipWith const (scanl (*) 1 extents) extents
-- | Opaque, uninhabited tag type used only behind a 'Ptr' (see the
-- @device_interface@ field of 'RawHalideBuffer'); presumably stands for
-- Halide's @halide_device_interface_t@.
data HalideDeviceInterface
-- | Haskell view of a @halide_buffer_t@. Its 56-byte, 8-byte-aligned layout
-- is fixed by the 'Storable' instance; pointers to it are passed to inline-C
-- code as @halide_buffer_t*@.
data RawHalideBuffer = RawHalideBuffer
  { halideBufferDevice :: !Word64
  -- ^ device handle; 0 when no device allocation is attached
  , halideBufferDeviceInterface :: !(Ptr HalideDeviceInterface)
  -- ^ device interface used for device operations, or 'nullPtr'
  , halideBufferHost :: !(Ptr Word8)
  -- ^ host data pointer
  , halideBufferFlags :: !Word64
  -- ^ flag bits (e.g. host/device dirty state)
  , halideBufferType :: !HalideType
  -- ^ element type
  , halideBufferDimensions :: !Int32
  -- ^ number of entries in 'halideBufferDim'
  , halideBufferDim :: !(Ptr HalideDimension)
  -- ^ per-dimension min/extent/stride/flags
  , halideBufferPadding :: !(Ptr ())
  -- ^ trailing pointer-sized field, kept to match the C layout
  }
  deriving stock (Show, Eq)
-- | A 'RawHalideBuffer' tagged at the type level with its number of
-- dimensions @n@ and element type @a@. The tags are phantom; runtime
-- agreement is checked where buffers are created or read.
newtype HalideBuffer (n :: Nat) (a :: Type) = HalideBuffer
  { unHalideBuffer :: RawHalideBuffer
  }
  deriving stock (Show, Eq)
-- Template Haskell splice (presumably from Language.Halide.Context) that makes
-- Halide's C/C++ declarations such as @halide_buffer_t@ available to the
-- inline-C quasi-quotes in this module -- confirm in Language.Halide.Context.
importHalide
-- | Field offsets: device 0, device_interface 8, host 16, flags 24, type 32,
-- dimensions 36, dim 40, padding 48; 56 bytes in total.
instance Storable RawHalideBuffer where
  sizeOf _ = 56
  alignment _ = 8

  peek ptr = do
    device <- peekByteOff ptr 0
    deviceInterface <- peekByteOff ptr 8
    host <- peekByteOff ptr 16
    flags <- peekByteOff ptr 24
    elemType <- peekByteOff ptr 32
    rank <- peekByteOff ptr 36
    dimArray <- peekByteOff ptr 40
    padding <- peekByteOff ptr 48
    pure
      ( RawHalideBuffer
          { halideBufferDevice = device
          , halideBufferDeviceInterface = deviceInterface
          , halideBufferHost = host
          , halideBufferFlags = flags
          , halideBufferType = elemType
          , halideBufferDimensions = rank
          , halideBufferDim = dimArray
          , halideBufferPadding = padding
          }
      )

  poke ptr buf = do
    pokeByteOff ptr 0 (halideBufferDevice buf)
    pokeByteOff ptr 8 (halideBufferDeviceInterface buf)
    pokeByteOff ptr 16 (halideBufferHost buf)
    pokeByteOff ptr 24 (halideBufferFlags buf)
    pokeByteOff ptr 32 (halideBufferType buf)
    pokeByteOff ptr 36 (halideBufferDimensions buf)
    pokeByteOff ptr 40 (halideBufferDim buf)
    pokeByteOff ptr 48 (halideBufferPadding buf)
-- | Describe existing host memory as a Halide buffer and run an action with it.
--
-- Every dimension gets min 0 and flags 0, with extents taken from @shape@ and
-- strides (in elements) from @stride@. The 'RawHalideBuffer' and its dimension
-- array are allocated with 'with' / 'withArrayLen', so the pointer given to
-- @action@ must not escape it. The host memory is referenced, not copied.
--
-- Raises an error when:
--
-- * fewer strides than extents are given (extra strides are ignored);
-- * the number of dimensions differs from the type-level @n@;
-- * after @action@ returns, the buffer's @device@ field is non-zero.
bufferFromPtrShapeStrides
  :: forall n a b
   . (HasCallStack, KnownNat n, IsHalideType a)
  => Ptr a
  -- ^ host data
  -> [Int]
  -- ^ extent of each dimension
  -> [Int]
  -- ^ stride of each dimension, in elements
  -> (Ptr (HalideBuffer n a) -> IO b)
  -> IO b
bufferFromPtrShapeStrides p shape stride action =
  withArrayLen dims $ \n dim -> do
    unless (n == fromIntegral (natVal (Proxy @n))) $
      error $
        "specified wrong number of dimensions: "
          <> show n
          <> "; expected "
          <> show (natVal (Proxy @n))
          <> " from the type declaration"
    let !buffer =
          RawHalideBuffer
            { halideBufferDevice = 0
            , halideBufferDeviceInterface = nullPtr
            , halideBufferHost = castPtr p
            , halideBufferFlags = 0
            , halideBufferType = halideTypeFor (Proxy :: Proxy a)
            , halideBufferDimensions = fromIntegral n
            , halideBufferDim = dim
            , halideBufferPadding = nullPtr
            }
    with buffer $ \bufferPtr -> do
      r <- action (castPtr bufferPtr)
      -- A non-zero device handle means device memory is still attached to this
      -- temporary buffer after the action finished.
      hasDataOnDevice <-
        toBool <$> [CU.exp| bool { $(halide_buffer_t* bufferPtr)->device } |]
      when hasDataOnDevice $
        error "the Buffer still references data on the device; did you forget to call copyToHost?"
      pure r
  where
    dims :: [HalideDimension]
    dims = pairUp shape stride

    -- Like 'zipWith simpleDimension', except that running out of strides
    -- before extents is reported instead of silently dropping dimensions. The
    -- stride list is only forced as far as the shape requires, so an empty
    -- shape never inspects it.
    pairUp (e : es) (s : ss) = simpleDimension e s : pairUp es ss
    pairUp [] _ = []
    pairUp (_ : _) [] =
      error $
        "fewer strides than dimensions: shape "
          <> show shape
          <> ", strides "
          <> show stride
-- | Like 'bufferFromPtrShapeStrides', but assumes the host data is densely
-- packed in column-major order: dimension 0 is contiguous and strides come
-- from 'colMajorStrides'.
bufferFromPtrShape
  :: (HasCallStack, KnownNat n, IsHalideType a)
  => Ptr a
  -- ^ host data
  -> [Int]
  -- ^ extent of each dimension
  -> (Ptr (HalideBuffer n a) -> IO b)
  -> IO b
bufferFromPtrShape p shape action =
  bufferFromPtrShapeStrides p shape strides action
  where
    strides = colMajorStrides shape
-- | Haskell containers of type @t@ that can be temporarily viewed as an
-- @n@-dimensional Halide buffer with elements of type @a@.
class (KnownNat n, IsHalideType a) => IsHalideBuffer t n a where
  -- | Run an action with a buffer pointer describing the container's data.
  -- Prefer 'withHalideBuffer', which fixes the type-argument order.
  withHalideBufferImpl :: t -> (Ptr (HalideBuffer n a) -> IO b) -> IO b
-- | Run an action with a temporary Halide buffer view of a Haskell value.
--
-- Type arguments are ordered @n@ (dimensions), @a@ (element type), @t@
-- (container), so callers can write e.g. @withHalideBuffer \@2 \@Float@.
withHalideBuffer
  :: forall n a t b
   . IsHalideBuffer t n a
  => t
  -> (Ptr (HalideBuffer n a) -> IO b)
  -> IO b
withHalideBuffer container action =
  withHalideBufferImpl @t @n @a container action
-- | A storable vector is a one-dimensional buffer over the vector's own
-- storage (no copy); the pointer is only valid inside the callback.
instance IsHalideType a => IsHalideBuffer (S.Vector a) 1 a where
  withHalideBufferImpl vec action =
    S.unsafeWith vec (\hostPtr -> bufferFromPtrShape hostPtr [S.length vec] action)
-- | A mutable storable vector is a one-dimensional buffer over its own
-- storage (no copy), so Halide pipelines may write into it.
instance IsHalideType a => IsHalideBuffer (S.MVector RealWorld a) 1 a where
  withHalideBufferImpl mvec action =
    SM.unsafeWith mvec (\hostPtr -> bufferFromPtrShape hostPtr [SM.length mvec] action)
-- | A list is copied into a fresh storable vector, which is then viewed as a
-- one-dimensional buffer.
instance IsHalideType a => IsHalideBuffer [a] 1 a where
  withHalideBufferImpl elems action =
    withHalideBuffer @1 @a (S.fromList elems) action
-- | A nested list @xs@ is copied into a two-dimensional buffer of shape
-- @[length xs, length (head xs)]@, where @xs !! i !! j@ is element @(i, j)@.
-- All rows must have the same length.
instance IsHalideType a => IsHalideBuffer [[a]] 2 a where
  withHalideBufferImpl :: forall b. [[a]] -> (Ptr (HalideBuffer 2 a) -> IO b) -> IO b
  withHalideBufferImpl xs f = do
    let d0 = length xs
        d1 = case xs of
          (row : _) -> length row
          [] -> 0
        -- bufferFromPtrShape uses column-major strides (dimension 0 is
        -- contiguous), so the data is flattened column by column.
        v = S.fromList (List.concat (List.transpose xs))
    -- Check every row: comparing only the total element count would accept
    -- ragged input such as [[1,2],[3],[4,5,6]] and scramble the data.
    unless (all ((== d1) . length) xs) $
      error "list doesn't have a regular shape (i.e. rows have varying number of elements)"
    S.unsafeWith v $ \cpuPtr ->
      bufferFromPtrShape cpuPtr [d0, d1] f
-- | A triply nested list @xs@ is copied into a three-dimensional buffer where
-- @xs !! i !! j !! k@ is element @(i, j, k)@. Every plane must have the same
-- number of rows and every row the same number of elements.
instance IsHalideType a => IsHalideBuffer [[[a]]] 3 a where
  withHalideBufferImpl :: forall b. [[[a]]] -> (Ptr (HalideBuffer 3 a) -> IO b) -> IO b
  withHalideBufferImpl xs f = do
    let d0 = length xs
        d1 = case xs of
          (plane : _) -> length plane
          [] -> 0
        d2 = case xs of
          ((row : _) : _) -> length row
          _ -> 0
        -- Check every plane and every row; comparing only the total element
        -- count would accept ragged input and scramble the data.
        regular =
          all (\plane -> length plane == d1 && all ((== d2) . length) plane) xs
        -- Column-major flattening (bufferFromPtrShape makes dimension 0
        -- contiguous): order elements by k, then j, then i (fastest).
        v =
          S.fromList
            . List.concat
            . List.concatMap List.transpose
            . List.transpose
            . fmap List.transpose
            $ xs
    unless regular $
      error "list doesn't have a regular shape (i.e. rows have varying number of elements)"
    S.unsafeWith v $ \cpuPtr ->
      bufferFromPtrShape cpuPtr [d0, d1, d2] f
-- | Run the monadic condition, then the action only if it yielded 'True'.
whenM :: Monad m => m Bool -> m () -> m ()
whenM cond action = cond >>= \ok -> when ok action
-- | Allocate uninitialised, densely packed column-major host memory for the
-- given shape, wrap it as an @n@-dimensional buffer, and run an action with it.
-- The memory is released when the action returns.
--
-- Raises an error if the action leaves the buffer marked device-dirty, since
-- this buffer has no device side to copy back from.
allocaCpuBuffer
  :: forall n a b
   . (HasCallStack, KnownNat n, IsHalideType a)
  => [Int]
  -- ^ extent of each dimension
  -> (Ptr (HalideBuffer n a) -> IO b)
  -> IO b
allocaCpuBuffer shape action =
  allocaArray elementCount $ \hostPtr ->
    bufferFromPtrShape hostPtr shape $ \bufPtr -> do
      result <- action bufPtr
      deviceDirty <- isDeviceDirty (castPtr bufPtr)
      when deviceDirty $
        error $
          "device_dirty is set on a CPU-only buffer; "
            <> "did you forget a copyToHost in your pipeline?"
      pure result
  where
    elementCount :: Int
    elementCount = foldl' (*) 1 shape
-- | Whether the buffer's @device_dirty()@ flag is set.
isDeviceDirty :: Ptr RawHalideBuffer -> IO Bool
isDeviceDirty p = do
  flag <- [CU.exp| bool { $(const halide_buffer_t* p)->device_dirty() } |]
  pure (toBool flag)
-- | Whether the buffer's @host_dirty()@ flag is set.
isHostDirty :: Ptr RawHalideBuffer -> IO Bool
isHostDirty p = do
  flag <- [CU.exp| bool { $(const halide_buffer_t* p)->host_dirty() } |]
  pure (toBool flag)
-- | If the buffer is marked device-dirty, copy its contents back to host
-- memory through the buffer's device interface; otherwise do nothing.
--
-- A C++ @std::runtime_error@ is thrown (and surfaced in Haskell via
-- 'C.throwBlock') when any of these hold:
--
-- * the buffer has no device interface;
-- * the buffer has no host allocation;
-- * @copy_to_host@ returns a non-zero error code. Previously this code was
--   ignored, which left stale host data without any signal.
bufferCopyToHost :: Ptr RawHalideBuffer -> IO ()
bufferCopyToHost p =
  [C.throwBlock| void {
    auto& buf = *$(halide_buffer_t* p);
    if (buf.device_dirty()) {
      if (buf.device_interface == nullptr) {
        throw std::runtime_error{"bufferCopyToHost: device_dirty is set, "
                                 "but device_interface is NULL"};
      }
      if (buf.host == nullptr) {
        throw std::runtime_error{"bufferCopyToHost: host is NULL; "
                                 "did you forget to allocate memory?"};
      }
      int const err = buf.device_interface->copy_to_host(nullptr, &buf);
      if (err != 0) {
        throw std::runtime_error{"bufferCopyToHost: copy_to_host failed"};
      }
    }
  } |]
-- | Fail with an error unless the buffer's runtime dimension count equals the
-- type-level @n@.
checkNumberOfDimensions
  :: forall n
   . (HasCallStack, KnownNat n)
  => RawHalideBuffer
  -> IO ()
checkNumberOfDimensions raw =
  when (fromIntegral expected /= actual) $
    error $
      "type-level and runtime number of dimensions do not match: "
        <> show expected
        <> " != "
        <> show actual
  where
    expected = natVal (Proxy @n)
    actual = halideBufferDimensions raw
-- | Buffer types whose host contents can be read back into Haskell lists.
class IsListPeek a where
  -- | Element type of the returned list (nested lists for buffers with more
  -- than one dimension).
  type ListPeekElem a :: Type

  -- | Read the buffer behind the pointer into a list.
  peekToList :: HasCallStack => Ptr a -> IO [ListPeekElem a]
instance IsHalideType a => IsListPeek (HalideBuffer 0 a) where
  type ListPeekElem (HalideBuffer 0 a) = a

  -- A zero-dimensional buffer holds a single element at its host pointer; it
  -- is returned as a singleton list. Fails on device-dirty buffers or a
  -- runtime rank other than 0.
  peekToList :: HasCallStack => Ptr (HalideBuffer 0 a) -> IO [ListPeekElem (HalideBuffer 0 a)]
  peekToList p = do
    dirty <- isDeviceDirty (castPtr p)
    when dirty $
      error "cannot peek data from device; call bufferCopyToHost first"
    raw <- peek (castPtr @_ @RawHalideBuffer p)
    checkNumberOfDimensions @0 raw
    scalar <- peek (castPtr @_ @a (halideBufferHost raw))
    pure [scalar]
instance IsHalideType a => IsListPeek (HalideBuffer 1 a) where
  type ListPeekElem (HalideBuffer 1 a) = a

  -- Read a one-dimensional host buffer into a list of its extent0 elements.
  -- Fails on device-dirty buffers or a runtime rank other than 1.
  peekToList
    :: HasCallStack
    => Ptr (HalideBuffer 1 a)
    -> IO [ListPeekElem (HalideBuffer 1 a)]
  peekToList p = do
    whenM (isDeviceDirty (castPtr p)) $
      error "cannot peek data from device; call bufferCopyToHost first"
    raw <- peek (castPtr @_ @RawHalideBuffer p)
    -- Validate the rank before reading dim[0], so the dimension array is never
    -- read out of bounds (matches the 0-dimensional instance).
    checkNumberOfDimensions @1 raw
    HalideDimension _ extent0 stride0 _ <- peekElemOff (halideBufferDim raw) 0
    -- Halide's host pointer addresses the element at the buffer's min
    -- coordinate, so the i-th element along a dimension lies i * stride
    -- elements away; the min must not be added. Offsets are computed in Int to
    -- avoid Int32 overflow.
    let ptr0 = castPtr @_ @a (halideBufferHost raw)
    forM [0 .. fromIntegral extent0 - 1 :: Int] $ \i0 ->
      peekElemOff ptr0 (i0 * fromIntegral stride0)
instance IsHalideType a => IsListPeek (HalideBuffer 2 a) where
  type ListPeekElem (HalideBuffer 2 a) = [a]

  -- Read a two-dimensional host buffer into nested lists indexed as
  -- result !! i0 !! i1, where iK runs over dimension K. Fails on
  -- device-dirty buffers or a runtime rank other than 2.
  peekToList
    :: HasCallStack
    => Ptr (HalideBuffer 2 a)
    -> IO [ListPeekElem (HalideBuffer 2 a)]
  peekToList p = do
    whenM (isDeviceDirty (castPtr p)) $
      error "cannot peek data from device; call bufferCopyToHost first"
    raw <- peek (castPtr @_ @RawHalideBuffer p)
    -- Validate the rank before reading dim[0..1], so the dimension array is
    -- never read out of bounds.
    checkNumberOfDimensions @2 raw
    HalideDimension _ extent0 stride0 _ <- peekElemOff (halideBufferDim raw) 0
    HalideDimension _ extent1 stride1 _ <- peekElemOff (halideBufferDim raw) 1
    -- Halide's host pointer addresses the element at the buffer's min
    -- coordinate, so the i-th element along a dimension lies i * stride
    -- elements away; the min must not be added. Offsets are computed in Int to
    -- avoid Int32 overflow.
    let ptr0 = castPtr @_ @a (halideBufferHost raw)
    forM [0 .. fromIntegral extent0 - 1 :: Int] $ \i0 -> do
      let ptr1 = ptr0 `advancePtr` (i0 * fromIntegral stride0)
      forM [0 .. fromIntegral extent1 - 1 :: Int] $ \i1 ->
        peekElemOff ptr1 (i1 * fromIntegral stride1)
instance IsHalideType a => IsListPeek (HalideBuffer 3 a) where
  type ListPeekElem (HalideBuffer 3 a) = [[a]]

  -- Read a three-dimensional host buffer into nested lists indexed as
  -- result !! i0 !! i1 !! i2, where iK runs over dimension K. Fails on
  -- device-dirty buffers or a runtime rank other than 3.
  peekToList
    :: HasCallStack
    => Ptr (HalideBuffer 3 a)
    -> IO [ListPeekElem (HalideBuffer 3 a)]
  peekToList p = do
    whenM (isDeviceDirty (castPtr p)) $
      error "cannot peek data from device; call bufferCopyToHost first"
    raw <- peek (castPtr @_ @RawHalideBuffer p)
    -- Validate the rank before reading dim[0..2], so the dimension array is
    -- never read out of bounds.
    checkNumberOfDimensions @3 raw
    HalideDimension _ extent0 stride0 _ <- peekElemOff (halideBufferDim raw) 0
    HalideDimension _ extent1 stride1 _ <- peekElemOff (halideBufferDim raw) 1
    HalideDimension _ extent2 stride2 _ <- peekElemOff (halideBufferDim raw) 2
    -- Halide's host pointer addresses the element at the buffer's min
    -- coordinate, so the i-th element along a dimension lies i * stride
    -- elements away; the min must not be added. Offsets are computed in Int to
    -- avoid Int32 overflow.
    let ptr0 = castPtr @_ @a (halideBufferHost raw)
    forM [0 .. fromIntegral extent0 - 1 :: Int] $ \i0 -> do
      let ptr1 = ptr0 `advancePtr` (i0 * fromIntegral stride0)
      forM [0 .. fromIntegral extent1 - 1 :: Int] $ \i1 -> do
        let ptr2 = ptr1 `advancePtr` (i1 * fromIntegral stride1)
        forM [0 .. fromIntegral extent2 - 1 :: Int] $ \i2 ->
          peekElemOff ptr2 (i2 * fromIntegral stride2)