{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE DeriveLift #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE InstanceSigs #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE QuasiQuotes #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE TypeFamilyDependencies #-}
{-# LANGUAGE UndecidableInstances #-}
{-# OPTIONS_GHC -Wno-unused-local-binds -Wno-unused-matches #-}
module Language.Halide.Type
  ( HalideTypeCode (..)
  , HalideType (..)
  , IsHalideType (..)
  , CxxExpr
  , CxxVar
  , CxxRVar
  , CxxVarOrRVar
  , CxxFunc
  , CxxParameter
  , CxxImageParam
  , CxxVector
  , CxxUserContext
  , CxxCallable
  , CxxTarget
  , CxxStageSchedule
  , CxxString
  , Arguments (..)
  , Length
  , Append
  , Concat
  , argumentsAppend
  , FunctionArguments
  , FunctionReturn
  , All
  , UnCurry (..)
  , Curry (..)
  , defineIsHalideTypeInstances
  , instanceHasCxxVector
  , HasCxxVector (..)
  , instanceCxxConstructible
  , CxxConstructible (..)
  
  
  
  
  
  )
where
import Data.Coerce
import Data.Constraint
import Data.Int
import Data.Kind (Type)
import Data.Text qualified as T
import Data.Word
import Foreign.C.Types
import Foreign.ForeignPtr
import Foreign.Marshal.Array (allocaArray, peekArray)
import Foreign.Ptr
import Foreign.Storable
import GHC.ForeignPtr (mallocForeignPtrAlignedBytes)
import GHC.TypeLits
import Language.C.Inline qualified as C
import Language.C.Inline.Unsafe qualified as CU
import Language.Haskell.TH qualified as TH
import Language.Haskell.TH.Syntax (Lift)
-- Opaque, constructor-less tags standing in for C++ classes. They are never
-- inhabited on the Haskell side and only appear as pointee types of 'Ptr' /
-- 'ForeignPtr'. The authoritative tag <-> C++ class mapping is the inline-c
-- type context, which is not part of this module; NOTE(review) the mappings
-- marked "presumably" are best guesses from the names.

-- | C++ callable object (presumably @Halide::Callable@).
data CxxCallable

-- | @Halide::Expr@.
data CxxExpr

-- | Presumably @Halide::Func@.
data CxxFunc

-- | Presumably @Halide::ImageParam@.
data CxxImageParam

-- | Scalar pipeline parameter (presumably a @Halide::Param@ variant).
data CxxParameter

-- | Presumably @Halide::RVar@.
data CxxRVar

-- | Schedule of a pipeline stage (presumably a Halide internal class).
data CxxStageSchedule

-- | Presumably @std::string@.
data CxxString

-- | Presumably @Halide::Target@.
data CxxTarget

-- | User context passed to compiled pipelines (presumably a Halide class).
data CxxUserContext

-- | Presumably @Halide::Var@.
data CxxVar

-- | Presumably @Halide::VarOrRVar@.
data CxxVarOrRVar

-- | A C++ @std::vector@ whose elements are described by the tag @a@.
data CxxVector a
-- | C++ types whose objects can be built into storage owned by a Haskell
-- 'ForeignPtr'. Instances in this module are produced by
-- 'instanceCxxConstructible'.
class CxxConstructible a where
  -- | @sizeof@ of the underlying C++ type, in bytes. The class variable does
  -- not appear in this type, so callers pick the instance with a type
  -- application (this is why the module enables @AllowAmbiguousTypes@).
  cxxSizeOf :: Int

  -- | Allocate storage for one object, run the supplied initializer on the
  -- raw pointer, and hand back a 'ForeignPtr' that owns the result.
  cxxConstruct :: (Ptr a -> IO ()) -> IO (ForeignPtr a)
-- | Allocate @size@ bytes of GC-managed memory aligned to 64 bytes, run
-- @constructor@ on it to initialise a C++ object in place, then register
-- @deleter@ as a C finalizer on the resulting 'ForeignPtr'.
--
-- The finalizer is attached only after @constructor@ has returned. If the
-- constructor throws, the deleter is therefore never run on storage that does
-- not hold a constructed object.
cxxConstructWithDeleter :: Int -> FinalizerPtr a -> (Ptr a -> IO ()) -> IO (ForeignPtr a)
cxxConstructWithDeleter size deleter constructor = do
  fp <- mallocForeignPtrAlignedBytes size align
  withForeignPtr fp constructor
  addForeignPtrFinalizer deleter fp
  pure fp
  where
    -- Generous fixed alignment so that any C++ type placed here is suitably
    -- aligned. NOTE(review): assumes no wrapped type needs more than 64 bytes.
    align :: Int
    align = 64
-- | Kind of a Halide scalar type: signed integer, unsigned integer,
-- floating point, opaque handle, or bfloat.
--
-- The numeric code exchanged with C is defined by the hand-written 'Enum'
-- instance rather than by constructor order. NOTE(review): intended to mirror
-- Halide's @halide_type_code_t@; keep in sync with the C header.
--
-- 'Lift' is derived because values of this type are spliced into the
-- Template Haskell quotes that generate 'IsHalideType' instances.
data HalideTypeCode
  = HalideTypeInt
  | HalideTypeUInt
  | HalideTypeFloat
  | HalideTypeHandle
  | HalideTypeBfloat
  deriving stock (Read, Show, Eq, Lift)
-- | Numeric encoding of 'HalideTypeCode'. These are the values stored in the
-- first byte of a 'HalideType' by its 'Storable' instance.
instance Enum HalideTypeCode where
  fromEnum :: HalideTypeCode -> Int
  fromEnum x = case x of
    HalideTypeInt -> 0
    HalideTypeUInt -> 1
    HalideTypeFloat -> 2
    HalideTypeHandle -> 3
    HalideTypeBfloat -> 4

  -- | Partial: any code outside 0..4 is rejected with an error.
  toEnum :: Int -> HalideTypeCode
  toEnum x = case x of
    0 -> HalideTypeInt
    1 -> HalideTypeUInt
    2 -> HalideTypeFloat
    3 -> HalideTypeHandle
    4 -> HalideTypeBfloat
    _ -> error $ "invalid HalideTypeCode: " <> show x

  -- The class defaults for the two open-ended enumerations are
  -- @map toEnum [fromEnum x ..]@ and @map toEnum [fromEnum x, fromEnum y ..]@.
  -- Both run over an unbounded Int range, so e.g. @[HalideTypeInt ..]@ would
  -- hit @toEnum 5@ and crash. Bound them by the last valid constructor in the
  -- direction of travel instead.
  enumFrom :: HalideTypeCode -> [HalideTypeCode]
  enumFrom x = enumFromTo x HalideTypeBfloat

  enumFromThen :: HalideTypeCode -> HalideTypeCode -> [HalideTypeCode]
  enumFromThen x y = enumFromThenTo x y lastCode
    where
      lastCode
        | fromEnum y >= fromEnum x = HalideTypeBfloat
        | otherwise = HalideTypeInt
-- | Description of a Halide value type: its type code, its width in bits,
-- and its vector lane count. The scalar 'IsHalideType' instances generated in
-- this module use a lane count of 1. The byte layout used at the C boundary is
-- fixed by the 'Storable' instance.
data HalideType = HalideType
  { halideTypeCode :: !HalideTypeCode
  , halideTypeBits :: {-# UNPACK #-} !Word8
  , halideTypeLanes :: {-# UNPACK #-} !Word16
  }
  deriving stock (Read, Show, Eq)
-- | Packed 4-byte layout:
--
--   * offset 0: type code, one byte (via the 'Enum' encoding of 'HalideTypeCode');
--   * offset 1: bit width, one byte;
--   * offset 2: lane count, two bytes.
instance Storable HalideType where
  sizeOf :: HalideType -> Int
  sizeOf _ = 4

  alignment :: HalideType -> Int
  alignment _ = 4

  -- | Decoding an out-of-range type code fails via 'toEnum'.
  peek :: Ptr HalideType -> IO HalideType
  peek p =
    HalideType
      <$> (toEnum . (fromIntegral :: Word8 -> Int) <$> peekByteOff p 0)
      <*> peekByteOff p 1
      <*> peekByteOff p 2

  poke :: Ptr HalideType -> HalideType -> IO ()
  poke p (HalideType code bits lanes) = do
    pokeByteOff p 0 . (fromIntegral :: Int -> Word8) . fromEnum $ code
    pokeByteOff p 1 bits
    pokeByteOff p 2 lanes
-- | Haskell scalar types with a Halide counterpart. The instances in this
-- module are generated by 'defineIsHalideTypeInstances'.
class Storable a => IsHalideType a where
  -- | Halide type descriptor for @a@. Only the type of the proxy matters; the
  -- generated instances never inspect its value.
  halideTypeFor :: proxy a -> HalideType

  -- | Wrap a value in a newly constructed C++ expression object.
  toCxxExpr :: a -> IO (ForeignPtr CxxExpr)
-- | Build a conversion from the Haskell type @hsType'@ to the Haskell type
-- that inline-c associates with the C type @cType@.
--
-- The result is 'id' when both types are the same. Otherwise it is 'coerce',
-- which only type-checks at the splice site if the two types are 'Coercible'
-- (presumably the case for pairs such as 'Float' and 'CFloat').
optionallyCast :: String -> TH.TypeQ -> TH.ExpQ
optionallyCast cType hsType' = do
  hsType <- hsType'
  hsTargetType <- C.getHaskellType False cType
  if hsType == hsTargetType then [e|id|] else [e|coerce|]
-- | Generate the 'IsHalideType' instance for one entry of 'halideTypes':
-- a C type name, the matching Haskell type, and its Halide type code.
--
-- In the quoted C++, the substitution T applied to a Haskell variable x
-- expands to the inline-c antiquote @$(cType x)@. That is, x is passed to C++
-- as a value of type @cType@.
instanceIsHalideType :: (String, TH.TypeQ, HalideTypeCode) -> TH.DecsQ
instanceIsHalideType (cType, hsType, typeCode) =
  C.substitute
    [("T", \x -> "$(" <> cType <> " " <> x <> ")")]
    [d|
      instance IsHalideType $hsType where
        -- Scalar instance: the bit width comes from the Storable size of the
        -- Haskell type and the lane count is 1. 'typeCode' is lifted into the
        -- quote, which is why HalideTypeCode derives Lift.
        halideTypeFor _ = HalideType typeCode bits 1
          where
            bits = fromIntegral $ 8 * sizeOf (undefined :: $hsType)
        -- Placement-new a Halide::Expr holding the value into the storage
        -- provided by cxxConstruct. 'x' is the value converted to the Haskell
        -- type inline-c expects for cType.
        toCxxExpr y =
          cxxConstruct $ \ptr ->
            [CU.exp| void { new ($(Halide::Expr* ptr)) Halide::Expr{@T(x)} } |]
          where
            x = $(optionallyCast cType hsType) y
      |]
-- | Template Haskell splice generating an 'IsHalideType' instance for every
-- entry of 'halideTypes'.
defineIsHalideTypeInstances :: TH.DecsQ
defineIsHalideTypeInstances = concat <$> traverse instanceIsHalideType halideTypes
-- | Generate a 'CxxConstructible' instance for the C++ class named @cType@.
-- The Haskell side of the instance is whatever type inline-c maps @cType@ to.
--
-- Substitutions available in the quoted C++:
--
--   * T: the full C++ type name;
--   * Deleter: the finalizer's signature, a function named deleter taking a
--     pointer to @cType@ called p;
--   * Class: the text after the last @::@ in @cType@, used to spell the
--     explicit destructor call.
--
-- NOTE(review): the Class trick assumes @cType@ is a plain (possibly
-- namespaced) class name. A templated name would produce an invalid
-- destructor call.
instanceCxxConstructible :: String -> TH.DecsQ
instanceCxxConstructible cType =
  C.substitute
    [ ("T", const cType)
    , ("Deleter", const $ "deleter(" <> cType <> "* p)")
    , ("Class", const . T.unpack . snd $ T.breakOnEnd "::" (T.pack cType))
    ]
    [d|
      instance CxxConstructible $(C.getHaskellType False cType) where
        cxxSizeOf = fromIntegral [CU.pure| size_t { sizeof(@T()) } |]
        -- The deleter only runs the destructor. The memory itself belongs to
        -- the ForeignPtr allocated by cxxConstructWithDeleter.
        cxxConstruct = cxxConstructWithDeleter size deleter
          where
            size = fromIntegral [CU.pure| size_t { sizeof(@T()) } |]
            deleter = [C.funPtr| void @Deleter() { p->~@Class()(); } |]
      |]
-- | Element types for which a heap-allocated C++ @std::vector@ can be driven
-- from Haskell through a raw pointer. Instances in this module come from
-- 'instanceHasCxxVector'.
class HasCxxVector a where
  -- | Allocate an empty vector, reserving capacity when a size hint is given.
  -- Release it with 'deleteCxxVector'.
  newCxxVector :: Maybe Int -> IO (Ptr (CxxVector a))

  -- | Destroy and free a vector obtained from 'newCxxVector'.
  deleteCxxVector :: Ptr (CxxVector a) -> IO ()

  -- | Current number of elements.
  cxxVectorSize :: Ptr (CxxVector a) -> IO Int

  -- | Append a copy of the pointed-to element.
  cxxVectorPushBack :: Ptr (CxxVector a) -> Ptr a -> IO ()

  -- | Pointer to the vector's contiguous element storage.
  cxxVectorData :: Ptr (CxxVector a) -> IO (Ptr a)

  -- | Read every element into a Haskell list.
  peekCxxVector :: Storable a => Ptr (CxxVector a) -> IO [a]
-- | Generate a 'HasCxxVector' instance for the C++ element type @cType@. The
-- Haskell element type is whatever inline-c maps @cType@ to.
--
-- Substitutions available in the quoted C++: T is the element type name.
-- VEC applied to a Haskell variable v expands to an inline-c antiquote that
-- passes v as a pointer to a @std::vector@ of @cType@.
instanceHasCxxVector :: String -> TH.DecsQ
instanceHasCxxVector cType =
  C.substitute
    [ ("T", const cType)
    , ("VEC", \var -> "$(std::vector<" ++ cType ++ ">* " ++ var ++ ")")
    ]
    [d|
      instance HasCxxVector $(C.getHaskellType False cType) where
        newCxxVector maybeSize = do
          v <- [CU.exp| std::vector<@T()>* { new std::vector<@T()>() } |]
          case maybeSize of
            Just size ->
              let n = fromIntegral size
               in [CU.exp| void { @VEC(v)->reserve($(size_t n)) } |]
            Nothing -> pure ()
          pure v
        deleteCxxVector vec = [CU.exp| void { delete @VEC(vec) } |]
        cxxVectorSize vec = fromIntegral <$> [CU.exp| size_t { @VEC(vec)->size() } |]
        cxxVectorPushBack vec x = [CU.exp| void { @VEC(vec)->push_back(*$(@T()* x)) } |]
        cxxVectorData vec = [CU.exp| @T()* { @VEC(vec)->data() } |]
        -- Copy-construct the elements into a temporary Haskell buffer, then
        -- peek them. NOTE(review): the copies are never destroyed, so this is
        -- only sound for trivially destructible element types.
        peekCxxVector vec = do
          n <- cxxVectorSize vec
          allocaArray n $ \out -> do
            [CU.block| void {
              auto const& vec = *@VEC(vec);
              auto* out = $(@T()* out);
              std::uninitialized_copy(std::begin(vec), std::end(vec), out);
            } |]
            peekArray n out
      |]
-- | The (C type name, Haskell type, Halide type code) triples for which
-- 'defineIsHalideTypeInstances' generates 'IsHalideType' instances.
--
-- float and double are listed with both the Prelude and the "Foreign.C.Types"
-- Haskell types. 'optionallyCast' inserts a 'coerce' wherever the Haskell type
-- differs from the one inline-c uses for the C type.
halideTypes :: [(String, TH.TypeQ, HalideTypeCode)]
halideTypes =
  [ ("float", [t|Float|], HalideTypeFloat)
  , ("float", [t|CFloat|], HalideTypeFloat)
  , ("double", [t|Double|], HalideTypeFloat)
  , ("double", [t|CDouble|], HalideTypeFloat)
  , ("int8_t", [t|Int8|], HalideTypeInt)
  , ("int16_t", [t|Int16|], HalideTypeInt)
  , ("int32_t", [t|Int32|], HalideTypeInt)
  , ("int64_t", [t|Int64|], HalideTypeInt)
  , ("uint8_t", [t|Word8|], HalideTypeUInt)
  , ("uint16_t", [t|Word16|], HalideTypeUInt)
  , ("uint32_t", [t|Word32|], HalideTypeUInt)
  , ("uint64_t", [t|Word64|], HalideTypeUInt)
  ]
-- | Strict heterogeneous list. The type-level list @ts@ records the type of
-- each element, in order.
data Arguments (ts :: [Type]) where
  -- | The empty list.
  Nil :: Arguments '[]
  -- | Prepend one value.
  (:::) :: !x -> !(Arguments xs) -> Arguments (x ': xs)

-- Right-associative, so @a ::: b ::: Nil@ means @a ::: (b ::: Nil)@.
infixr 5 :::
-- | Number of elements in a type-level list.
type family Length (xs :: [k]) :: Nat where
  Length '[] = 0
  Length (_ ': rest) = 1 + Length rest
-- | Add @y@ as the last element of the type-level list @xs@.
type family Append (xs :: [k]) (y :: k) :: [k] where
  Append '[] z = '[z]
  Append (h ': t) z = h ': Append t z
-- | Concatenation of two type-level lists.
type family Concat (xs :: [k]) (ys :: [k]) :: [k] where
  Concat '[] rest = rest
  Concat (h ': t) rest = h ': Concat t rest
-- | Add a value at the end of an 'Arguments' list. The type-level list is
-- extended with 'Append' to match. Runs in time linear in the list length.
argumentsAppend :: Arguments xs -> t -> Arguments (Append xs t)
argumentsAppend Nil y = y ::: Nil
argumentsAppend (x ::: rest) y = x ::: argumentsAppend rest y
-- | Argument types of a curried function type, in order. A type that is not
-- a function type yields the empty list.
type family FunctionArguments (f :: Type) :: [Type] where
  FunctionArguments (arg -> rest) = arg ': FunctionArguments rest
  FunctionArguments _ = '[]
-- | Final result type of a curried function type. A type that is not a
-- function type is its own result.
type family FunctionReturn (f :: Type) :: Type where
  FunctionReturn (_ -> rest) = FunctionReturn rest
  FunctionReturn result = result
-- | Conjunction of the constraint @c@ over every element of @ts@. The family
-- is declared injective (@p -> ts@), which requires @UndecidableInstances@
-- because @ts@ reappears under a recursive call.
type family All (c :: Type -> Constraint) (ts :: [Type]) = (p :: Constraint) | p -> ts where
  All c '[] = ()
  All c (x ': xs) = (c x, All c xs)
-- | Apply a curried function @f@ to a whole 'Arguments' list at once,
-- producing @r@. The functional dependencies let GHC determine @f@ from the
-- argument list and result, and @r@ from the argument list and function.
class UnCurry (f :: Type) (args :: [Type]) (r :: Type) | args r -> f, args f -> r where
  -- | Feed the arguments to the function one at a time, left to right.
  uncurryG :: f -> Arguments args -> r
-- | Base case: with no arguments left, the function value is the result.
-- The equality constraints pin @f@ down as a non-function type equal to @r@.
instance (FunctionArguments f ~ '[], FunctionReturn f ~ r, f ~ r) => UnCurry f '[] r where
  uncurryG :: f -> Arguments '[] -> r
  uncurryG f Nil = f
  {-# INLINE uncurryG #-}
-- | Step case: apply the function to the head argument and recurse on the
-- remaining arguments.
instance (UnCurry f args r) => UnCurry (a -> f) (a ': args) r where
  uncurryG :: (a -> f) -> Arguments (a ': args) -> r
  uncurryG f (a ::: args) = uncurryG (f a) args
  {-# INLINE uncurryG #-}
-- | Inverse of 'UnCurry': turn a function on an 'Arguments' list into an
-- ordinary curried function @f@. The functional dependency determines @f@
-- from the argument list and the result type.
class Curry (args :: [Type]) (r :: Type) (f :: Type) | args r -> f where
  -- | Collect curried arguments into an 'Arguments' list, then call the
  -- supplied function on it.
  curryG :: (Arguments args -> r) -> f
-- | Base case: no arguments to collect, so call the function on 'Nil'.
instance Curry '[] r r where
  curryG :: (Arguments '[] -> r) -> r
  curryG f = f Nil
  {-# INLINE curryG #-}
-- | Step case: take one argument, prepend it to the list still being
-- collected, and continue currying over the remaining argument types.
instance Curry args r f => Curry (a ': args) r (a -> f) where
  curryG :: (Arguments (a ': args) -> r) -> a -> f
  curryG f a = curryG (\args -> f (a ::: args))