{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE DeriveLift #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE InstanceSigs #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE QuasiQuotes #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE TypeFamilyDependencies #-}
{-# LANGUAGE UndecidableInstances #-}
{-# OPTIONS_GHC -Wno-unused-local-binds -Wno-unused-matches #-}
module Language.Halide.Type
( HalideTypeCode (..)
, HalideType (..)
, IsHalideType (..)
, CxxExpr
, CxxVar
, CxxRVar
, CxxVarOrRVar
, CxxFunc
, CxxParameter
, CxxImageParam
, CxxVector
, CxxUserContext
, CxxCallable
, CxxTarget
, CxxStageSchedule
, CxxString
, Arguments (..)
, Length
, Append
, Concat
, argumentsAppend
, FunctionArguments
, FunctionReturn
, All
, UnCurry (..)
, Curry (..)
, defineIsHalideTypeInstances
, instanceHasCxxVector
, HasCxxVector (..)
, instanceCxxConstructible
, CxxConstructible (..)
)
where
import Data.Coerce
import Data.Constraint
import Data.Int
import Data.Kind (Type)
import Data.Text qualified as T
import Data.Word
import Foreign.C.Types
import Foreign.ForeignPtr
import Foreign.Ptr
import Foreign.Storable
import GHC.ForeignPtr (mallocForeignPtrAlignedBytes)
import GHC.TypeLits
import Language.C.Inline qualified as C
import Language.C.Inline.Unsafe qualified as CU
import Language.Haskell.TH qualified as TH
import Language.Haskell.TH.Syntax (Lift)
-- Opaque tag types with no Haskell-side values. They stand in for C++ objects
-- and are only ever handled behind 'Ptr' / 'ForeignPtr'. The mapping to
-- concrete C++ classes is configured in the inline-C type context, which is
-- not visible in this module's code (presumably e.g. CxxVar ~ Halide::Var,
-- CxxString ~ std::string — confirm against the context definition).

-- | Tag for @Halide::Expr@ ('toCxxExpr' placement-constructs a
-- @Halide::Expr@ into memory typed as 'CxxExpr').
data CxxExpr
data CxxVar
data CxxRVar
data CxxVarOrRVar
data CxxParameter
data CxxImageParam
data CxxFunc
data CxxUserContext
data CxxCallable
data CxxTarget
-- | Tag for a C++ @std::vector@ of elements of type @a@ (see 'HasCxxVector').
data CxxVector a
data CxxStageSchedule
data CxxString
-- | C++ types whose objects can be built in Haskell-allocated storage.
--
-- Instances are normally produced by 'instanceCxxConstructible'.
class CxxConstructible a where
  -- | Size in bytes of one C++ object. The class variable does not occur in
  -- this signature, so callers must pick the instance with a type
  -- application (which is what requires @AllowAmbiguousTypes@).
  cxxSizeOf :: Int

  -- | Allocate storage, run the supplied in-place constructor on it, and
  -- return the resulting object. In instances generated by
  -- 'instanceCxxConstructible', the returned 'ForeignPtr' runs the C++
  -- destructor when it is finalized.
  cxxConstruct :: (Ptr a -> IO ()) -> IO (ForeignPtr a)
-- | Build a C++ object in garbage-collected, 64-byte-aligned storage.
--
-- * @size@: number of bytes to allocate for the object.
-- * @deleter@: C function pointer attached as a finalizer; it is expected to
--   destroy the object (the storage itself is reclaimed by the GHC heap).
-- * @constructor@: initialises the object in place at the given address.
--
-- The finalizer is registered only after @constructor@ returns normally. If
-- @constructor@ throws, @deleter@ is therefore never run on a half-built
-- object.
cxxConstructWithDeleter :: Int -> FinalizerPtr a -> (Ptr a -> IO ()) -> IO (ForeignPtr a)
cxxConstructWithDeleter size deleter constructor = do
  storage <- mallocForeignPtrAlignedBytes size alignmentBytes
  withForeignPtr storage constructor
  addForeignPtrFinalizer deleter storage
  return storage
  where
    alignmentBytes :: Int
    alignmentBytes = 64
-- | Kind of scalar described by a 'HalideType'.
--
-- The 'Enum' instance below fixes the numeric encoding (0–4) that the
-- 'Storable' instance of 'HalideType' writes as a single byte. Presumably
-- these numbers mirror Halide's @halide_type_code_t@ — confirm against the
-- Halide headers.
data HalideTypeCode
  = HalideTypeInt
  | HalideTypeUInt
  | HalideTypeFloat
  | HalideTypeHandle
  | HalideTypeBfloat
  deriving stock (Read, Show, Eq, Lift)

-- | Explicit numeric encoding of 'HalideTypeCode'.
--
-- 'enumFrom' and 'enumFromThen' are overridden because the class defaults
-- (@map toEnum [fromEnum x ..]@) do not know where the type ends. They keep
-- calling 'toEnum' past the last constructor, so @[HalideTypeInt ..]@ used to
-- crash with "invalid HalideTypeCode: 5" after yielding five elements.
instance Enum HalideTypeCode where
  fromEnum :: HalideTypeCode -> Int
  fromEnum x = case x of
    HalideTypeInt -> 0
    HalideTypeUInt -> 1
    HalideTypeFloat -> 2
    HalideTypeHandle -> 3
    HalideTypeBfloat -> 4

  toEnum :: Int -> HalideTypeCode
  toEnum x = case x of
    0 -> HalideTypeInt
    1 -> HalideTypeUInt
    2 -> HalideTypeFloat
    3 -> HalideTypeHandle
    4 -> HalideTypeBfloat
    _ -> error $ "invalid HalideTypeCode: " <> show x

  enumFrom :: HalideTypeCode -> [HalideTypeCode]
  enumFrom x = enumFromTo x HalideTypeBfloat

  enumFromThen :: HalideTypeCode -> HalideTypeCode -> [HalideTypeCode]
  enumFromThen x y = enumFromThenTo x y limit
    where
      -- Ascending progressions stop at the last constructor, descending ones
      -- at the first (same convention as derived Enum instances).
      limit
        | fromEnum y >= fromEnum x = HalideTypeBfloat
        | otherwise = HalideTypeInt
-- | Description of a Halide element type: what kind of scalar it is, how many
-- bits it has, and how many vector lanes. Its 'Storable' instance packs it
-- into 4 bytes (presumably matching Halide's C @halide_type_t@ — confirm).
data HalideType = HalideType
  { halideTypeCode :: !HalideTypeCode
  -- ^ Scalar kind (signed int, unsigned int, float, handle, bfloat).
  , halideTypeBits :: {-# UNPACK #-} !Word8
  -- ^ Width of one lane in bits.
  , halideTypeLanes :: {-# UNPACK #-} !Word16
  -- ^ Number of vector lanes.
  }
  deriving stock (Eq, Show)
  deriving stock (Read)
-- | 4-byte, 4-aligned encoding of a 'HalideType':
--
-- * byte 0: the type code as its 'fromEnum' value, truncated to a byte;
-- * byte 1: the bit width;
-- * bytes 2–3: the lane count as a host-order 'Word16'.
--
-- 'peek' fails via 'toEnum' if byte 0 is not a valid 'HalideTypeCode'.
instance Storable HalideType where
  sizeOf :: HalideType -> Int
  sizeOf _ = 4

  alignment :: HalideType -> Int
  alignment _ = 4

  peek :: Ptr HalideType -> IO HalideType
  peek ptr = do
    rawCode <- peekByteOff ptr 0 :: IO Word8
    bitWidth <- peekByteOff ptr 1
    laneCount <- peekByteOff ptr 2
    pure (HalideType (toEnum (fromIntegral rawCode)) bitWidth laneCount)

  poke :: Ptr HalideType -> HalideType -> IO ()
  poke ptr (HalideType typeCode bitWidth laneCount) = do
    let rawCode = fromIntegral (fromEnum typeCode) :: Word8
    pokeByteOff ptr 0 rawCode
    pokeByteOff ptr 1 bitWidth
    pokeByteOff ptr 2 laneCount
-- | Haskell scalar types that have a Halide counterpart.
--
-- Instances for the standard numeric types are generated by
-- 'defineIsHalideTypeInstances'.
class Storable a => IsHalideType a where
  -- | The 'HalideType' corresponding to @a@. Only the proxy's type matters;
  -- the generated instances never inspect its value.
  halideTypeFor :: proxy a -> HalideType

  -- | Wrap a Haskell value in a freshly constructed C++ @Halide::Expr@.
  toCxxExpr :: a -> IO (ForeignPtr CxxExpr)
-- | Build an expression that converts a value of Haskell type @hsTypeQ@ into
-- the Haskell type inline-C associates with the C type @cType@.
--
-- It yields @id@ when the two types are identical and @coerce@ otherwise.
-- The @coerce@ case only typechecks when the types are representationally
-- equal, e.g. a newtype such as 'CFloat' and its underlying type.
optionallyCast :: String -> TH.TypeQ -> TH.ExpQ
optionallyCast cType hsTypeQ = do
  actualType <- hsTypeQ
  expectedType <- C.getHaskellType False cType
  if actualType /= expectedType
    then [e|coerce|]
    else [e|id|]
-- | Generate an 'IsHalideType' instance from a (C type name, Haskell type,
-- type code) triple.
--
-- The generated 'halideTypeFor' reports a single-lane type whose bit width
-- is @8 * sizeOf@ of the Haskell type. The generated 'toCxxExpr'
-- placement-constructs a @Halide::Expr@ from the value, after converting the
-- value with 'optionallyCast'.
instanceIsHalideType :: (String, TH.TypeQ, HalideTypeCode) -> TH.DecsQ
instanceIsHalideType (cType, hsType, typeCode) =
  C.substitute [("T", antiQuoteAs)] declarations
  where
    -- @T(v) inside the quoted C code expands to the anti-quote "$(cType v)",
    -- passing the Haskell variable v as a value of C type cType.
    antiQuoteAs var = "$(" <> cType <> " " <> var <> ")"

    declarations =
      [d|
        instance IsHalideType $hsType where
          halideTypeFor _ = HalideType typeCode bits 1
            where
              bits = fromIntegral $ 8 * sizeOf (undefined :: $hsType)
          toCxxExpr y =
            cxxConstruct $ \ptr ->
              [CU.exp| void { new ($(Halide::Expr* ptr)) Halide::Expr{@T(x)} } |]
            where
              x = $(optionallyCast cType hsType) y
      |]
-- | Splice an 'IsHalideType' instance for every entry of 'halideTypes', in
-- list order.
defineIsHalideTypeInstances :: TH.DecsQ
defineIsHalideTypeInstances = do
  perTypeDecls <- traverse instanceIsHalideType halideTypes
  pure (concat perTypeDecls)
-- | Generate a 'CxxConstructible' instance for the Haskell type that inline-C
-- associates with the C++ type @cType@.
--
-- Both 'cxxSizeOf' and the allocation size use C++ @sizeof(cType)@. The
-- finalizer is a generated C function that calls the object's destructor in
-- place.
instanceCxxConstructible :: String -> TH.DecsQ
instanceCxxConstructible cType = C.substitute macros declarations
  where
    -- Last component of a qualified name, e.g. "Halide::Expr" becomes "Expr".
    -- It is used to spell the destructor call p->~Expr().
    unqualifiedName = T.unpack . snd . T.breakOnEnd "::" . T.pack $ cType

    macros =
      [ ("T", const cType)
      , ("Deleter", const ("deleter(" <> cType <> "* p)"))
      , ("Class", const unqualifiedName)
      ]

    declarations =
      [d|
        instance CxxConstructible $(C.getHaskellType False cType) where
          cxxSizeOf = fromIntegral [CU.pure| size_t { sizeof(@T()) } |]
          cxxConstruct = cxxConstructWithDeleter size deleter
            where
              size = fromIntegral [CU.pure| size_t { sizeof(@T()) } |]
              deleter = [C.funPtr| void @Deleter() { p->~@Class()(); } |]
      |]
-- | Element types for which a C++ @std::vector@ can be driven from Haskell.
--
-- Instances are generated by 'instanceHasCxxVector'. The descriptions below
-- refer to those generated instances.
class HasCxxVector a where
  -- | Heap-allocate an empty vector, reserving the given capacity if any.
  newCxxVector :: Maybe Int -> IO (Ptr (CxxVector a))

  -- | Free a vector obtained from 'newCxxVector'.
  deleteCxxVector :: Ptr (CxxVector a) -> IO ()

  -- | Current number of elements.
  cxxVectorSize :: Ptr (CxxVector a) -> IO Int

  -- | Append a copy of the element pointed to by the second argument.
  cxxVectorPushBack :: Ptr (CxxVector a) -> Ptr a -> IO ()

  -- | Pointer to the vector's contiguous storage (@std::vector::data@). Like
  -- any such pointer, it is invalidated when the vector reallocates.
  cxxVectorData :: Ptr (CxxVector a) -> IO (Ptr a)

  -- | Copy all elements into a Haskell list via their 'Storable' instance.
  peekCxxVector :: Storable a => Ptr (CxxVector a) -> IO [a]
-- | Generate a 'HasCxxVector' instance for the Haskell type that inline-C
-- associates with the C++ type @cType@, backed by @std::vector<cType>@.
--
-- Inside the quoted C code, @T() expands to @cType@. @VEC(v) expands to an
-- anti-quote that passes the Haskell pointer @v@ as a @std::vector<cType>*@.
instanceHasCxxVector :: String -> TH.DecsQ
instanceHasCxxVector cType =
  C.substitute
    [ ("T", const cType)
    , ("VEC", \var -> "$(std::vector<" ++ cType ++ ">* " ++ var ++ ")")
    ]
    [d|
      instance HasCxxVector $(C.getHaskellType False cType) where
        newCxxVector maybeSize = do
          -- Reject a negative capacity before allocating anything. Converted
          -- to size_t it would wrap to a huge value, and reserve() would throw
          -- std::length_error out of an unprotected CU.exp call, which Haskell
          -- cannot catch. Checking first also avoids leaking the new vector.
          case maybeSize of
            Just size
              | size < 0 ->
                  ioError . userError $
                    "newCxxVector: negative capacity: " <> show size
            _ -> pure ()
          v <- [CU.exp| std::vector<@T()>* { new std::vector<@T()>() } |]
          case maybeSize of
            Just size ->
              let n = fromIntegral size
               in [CU.exp| void { @VEC(v)->reserve($(size_t n)) } |]
            Nothing -> pure ()
          pure v
        deleteCxxVector vec = [CU.exp| void { delete @VEC(vec) } |]
        cxxVectorSize vec = fromIntegral <$> [CU.exp| size_t { @VEC(vec)->size() } |]
        cxxVectorPushBack vec x = [CU.exp| void { @VEC(vec)->push_back(*$(@T()* x)) } |]
        cxxVectorData vec = [CU.exp| @T()* { @VEC(vec)->data() } |]
        -- Copies the elements into a temporary buffer with
        -- std::uninitialized_copy and reads them back via Storable. This
        -- assumes the C++ element layout matches the Haskell Storable
        -- instance (presumably only trivially copyable element types are
        -- used here — confirm).
        peekCxxVector vec = do
          n <- cxxVectorSize vec
          allocaArray n $ \out -> do
            [CU.block| void {
              auto const& vec = *@VEC(vec);
              auto* out = $(@T()* out);
              std::uninitialized_copy(std::begin(vec), std::end(vec), out);
            } |]
            peekArray n out
    |]
-- | (C type name, Haskell type, Halide type code) for every scalar type that
-- 'defineIsHalideTypeInstances' generates an 'IsHalideType' instance for.
-- Floating types come first, then signed and then unsigned integers.
halideTypes :: [(String, TH.TypeQ, HalideTypeCode)]
halideTypes = floating ++ signed ++ unsigned
  where
    floating, signed, unsigned :: [(String, TH.TypeQ, HalideTypeCode)]
    floating =
      [ ("float", [t|Float|], HalideTypeFloat)
      , ("float", [t|CFloat|], HalideTypeFloat)
      , ("double", [t|Double|], HalideTypeFloat)
      , ("double", [t|CDouble|], HalideTypeFloat)
      ]
    signed =
      [ ("int8_t", [t|Int8|], HalideTypeInt)
      , ("int16_t", [t|Int16|], HalideTypeInt)
      , ("int32_t", [t|Int32|], HalideTypeInt)
      , ("int64_t", [t|Int64|], HalideTypeInt)
      ]
    unsigned =
      [ ("uint8_t", [t|Word8|], HalideTypeUInt)
      , ("uint16_t", [t|Word16|], HalideTypeUInt)
      , ("uint32_t", [t|Word32|], HalideTypeUInt)
      , ("uint64_t", [t|Word64|], HalideTypeUInt)
      ]
-- | A heterogeneous list whose element types are recorded in the type-level
-- list @types@. It is built with ':::' (right-associative, strict in both
-- fields) and terminated by 'Nil'.
data Arguments (types :: [Type]) where
  Nil :: Arguments '[]
  (:::) :: !t -> !(Arguments ts) -> Arguments (t ': ts)

infixr 5 :::
-- | Number of elements in a type-level list.
type family Length (xs :: [k]) :: Nat where
  Length '[] = 0
  Length (_ ': rest) = 1 + Length rest

-- | Add one element to the end of a type-level list.
type family Append (xs :: [k]) (y :: k) :: [k] where
  Append '[] end = '[end]
  Append (hd ': tl) end = hd ': Append tl end

-- | Concatenate two type-level lists.
type family Concat (xs :: [k]) (ys :: [k]) :: [k] where
  Concat '[] suffix = suffix
  Concat (hd ': tl) suffix = hd ': Concat tl suffix
-- | Add a value to the end of an 'Arguments' list. This walks the whole list,
-- so it takes time linear in its length.
argumentsAppend :: Arguments xs -> t -> Arguments (Append xs t)
argumentsAppend Nil final = final ::: Nil
argumentsAppend (front ::: rest) final = front ::: argumentsAppend rest final
-- | Argument types of a curried function type, in order. A non-function type
-- has no arguments.
type family FunctionArguments (f :: Type) :: [Type] where
  FunctionArguments (arg -> rest) = arg ': FunctionArguments rest
  FunctionArguments _ = '[]

-- | Final result type of a curried function type, after all arguments have
-- been applied. A non-function type is its own result.
type family FunctionReturn (f :: Type) :: Type where
  FunctionReturn (_ -> rest) = FunctionReturn rest
  FunctionReturn result = result
-- | @All cls types@ is the conjunction of @cls t@ for every @t@ in @types@.
-- It is declared injective in the list argument.
type family All (cls :: Type -> Constraint) (types :: [Type]) = (result :: Constraint) | result -> types where
  All cls '[] = ()
  All cls (t ': rest) = (cls t, All cls rest)
-- | Apply a curried function @f@ to an 'Arguments' list, yielding its final
-- result @r@. The functional dependencies let the argument list and either
-- the function or the result determine the remaining type.
class UnCurry (f :: Type) (args :: [Type]) (r :: Type) | args r -> f, args f -> r where
  uncurryG :: f -> Arguments args -> r

-- | Base case: with no arguments left, the (non-function) value is the result.
instance (FunctionArguments f ~ '[], FunctionReturn f ~ r, f ~ r) => UnCurry f '[] r where
  uncurryG :: f -> Arguments '[] -> r
  uncurryG result Nil = result
  {-# INLINE uncurryG #-}

-- | Step case: apply the function to the head argument and recurse on the rest.
instance (UnCurry f args r) => UnCurry (a -> f) (a ': args) r where
  uncurryG :: (a -> f) -> Arguments (a ': args) -> r
  uncurryG fn (arg ::: rest) = uncurryG (fn arg) rest
  {-# INLINE uncurryG #-}
-- | Turn a function over an 'Arguments' list into an ordinary curried
-- function @f@ that takes the list's elements one at a time.
class Curry (args :: [Type]) (r :: Type) (f :: Type) | args r -> f where
  curryG :: (Arguments args -> r) -> f

-- | Base case: no arguments to collect, so call the function on 'Nil'.
instance Curry '[] r r where
  curryG :: (Arguments '[] -> r) -> r
  curryG k = k Nil
  {-# INLINE curryG #-}

-- | Step case: take one argument, prepend it, and keep collecting.
instance Curry args r f => Curry (a ': args) r (a -> f) where
  curryG :: (Arguments (a ': args) -> r) -> a -> f
  curryG k a = curryG (k . (a :::))
  {-# INLINE curryG #-}