{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE DeriveLift #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE InstanceSigs #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE QuasiQuotes #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE TypeFamilyDependencies #-}
{-# LANGUAGE UndecidableInstances #-}
{-# OPTIONS_GHC -Wno-unused-local-binds -Wno-unused-matches #-}

-- |
-- Module      : Language.Halide.Type
-- Description : Low-level types
-- Copyright   : (c) Tom Westerhout, 2023
module Language.Halide.Type
  ( HalideTypeCode (..)
  , HalideType (..)
  , IsHalideType (..)
  , CxxExpr
  , CxxVar
  , CxxRVar
  , CxxVarOrRVar
  , CxxFunc
  , CxxParameter
  , CxxImageParam
  , CxxVector
  , CxxUserContext
  , CxxCallable
  , CxxTarget
  , CxxStageSchedule
  , CxxString
  , Arguments (..)
  , Length
  , Append
  , Concat
  , argumentsAppend
  , FunctionArguments
  , FunctionReturn
  , All
  , UnCurry (..)
  , Curry (..)
  , defineIsHalideTypeInstances
  , instanceHasCxxVector
  , HasCxxVector (..)
  , instanceCxxConstructible
  , CxxConstructible (..)
  -- defineCastableInstances,
  -- defineCurriedTypeFamily,
  -- defineUnCurriedTypeFamily,
  -- defineCurryInstances,
  -- defineUnCurryInstances,
  )
where

import Data.Coerce
import Data.Constraint
import Data.Int
import Data.Kind (Type)
import Data.Text qualified as T
import Data.Word
import Foreign.C.Types
import Foreign.ForeignPtr
import Foreign.Ptr
import Foreign.Storable
import GHC.ForeignPtr (mallocForeignPtrAlignedBytes)
import GHC.TypeLits
import Language.C.Inline qualified as C
import Language.C.Inline.Unsafe qualified as CU
import Language.Haskell.TH qualified as TH
import Language.Haskell.TH.Syntax (Lift)

-- The empty data types below have no Haskell values. They are only used as
-- type-level tags for pointers to C++ objects, e.g. @ForeignPtr CxxExpr@ or
-- @Ptr (CxxVector a)@.

-- | Haskell counterpart of @Halide::Expr@.
data CxxExpr

-- | Haskell counterpart of @Halide::Var@.
data CxxVar

-- | Haskell counterpart of @Halide::RVar@.
data CxxRVar

-- | Haskell counterpart of @Halide::VarOrRVar@.
data CxxVarOrRVar

-- | Haskell counterpart of @Halide::Internal::Parameter@.
data CxxParameter

-- | Haskell counterpart of @Halide::ImageParam@.
data CxxImageParam

-- | Haskell counterpart of @Halide::Func@.
data CxxFunc

-- | Haskell counterpart of @Halide::JITUserContext@.
data CxxUserContext

-- | Haskell counterpart of @Halide::Callable@.
data CxxCallable

-- | Haskell counterpart of @Halide::Target@.
data CxxTarget

-- | Haskell counterpart of @std::vector@. The type parameter is the Haskell
-- type of the elements (see 'HasCxxVector').
data CxxVector a

-- | Haskell counterpart of @Halide::Internal::StageSchedule@.
data CxxStageSchedule

-- | Haskell counterpart of @std::string@.
data CxxString

-- | C++ types that Haskell code can construct in place and then own through a
-- 'ForeignPtr'.
--
-- Instances are normally generated with 'instanceCxxConstructible'.
class CxxConstructible a where
  -- | Size of the C++ type in bytes.
  --
  -- The class parameter does not occur in this type, so call sites must pick
  -- the instance with a type application, e.g. @cxxSizeOf \@CxxExpr@.
  cxxSizeOf :: Int

  -- | Allocate storage for one object, run the supplied in-place constructor
  -- on that storage, and hand the result back as a 'ForeignPtr'.
  cxxConstruct :: (Ptr a -> IO ()) -> IO (ForeignPtr a)

-- | Build a C++ object inside Haskell-managed memory.
--
-- Allocates @size@ bytes of pinned memory from GHC's heap, aligned to 64
-- bytes. It then runs @constructor@ on that memory to create the object in
-- place. Finally it registers @deleter@ as a finalizer of the returned
-- 'ForeignPtr', so the object can be destroyed once the pointer becomes
-- unreachable.
cxxConstructWithDeleter :: Int -> FinalizerPtr a -> (Ptr a -> IO ()) -> IO (ForeignPtr a)
cxxConstructWithDeleter size deleter constructor = do
  storage <- mallocForeignPtrAlignedBytes size objectAlignment
  withForeignPtr storage constructor
  -- Registered only after the constructor has returned, so the deleter is
  -- never attached to storage that was not initialised.
  addForeignPtrFinalizer deleter storage
  return storage
  where
    objectAlignment :: Int
    objectAlignment = 64

-- data Split =
--   SplitVar !Text !Text !Text !(Expr Int32) !

-- | Haskell counterpart of @halide_type_code_t@.
data HalideTypeCode
  = HalideTypeInt
  | HalideTypeUInt
  | HalideTypeFloat
  | HalideTypeHandle
  | HalideTypeBfloat
  deriving stock (Read, Show, Eq, Lift)

-- | Explicit numbering of the type codes.
--
-- 'fromEnum' gives the value that the 'Storable' instance of 'HalideType'
-- writes as the code byte, and 'toEnum' is its inverse. 'toEnum' fails with an
-- error for values outside @[0 .. 4]@.
--
-- 'enumFrom' and 'enumFromThen' must be defined here. Their class defaults map
-- 'toEnum' over an unbounded 'Int' range, so e.g. @[HalideTypeInt ..]@ would
-- call @toEnum 5@ and crash instead of stopping at the last constructor. The
-- definitions below follow the semantics of a derived 'Enum' instance.
instance Enum HalideTypeCode where
  fromEnum :: HalideTypeCode -> Int
  fromEnum x = case x of
    HalideTypeInt -> 0
    HalideTypeUInt -> 1
    HalideTypeFloat -> 2
    HalideTypeHandle -> 3
    HalideTypeBfloat -> 4

  toEnum :: Int -> HalideTypeCode
  toEnum x = case x of
    0 -> HalideTypeInt
    1 -> HalideTypeUInt
    2 -> HalideTypeFloat
    3 -> HalideTypeHandle
    4 -> HalideTypeBfloat
    _ -> error $ "invalid HalideTypeCode: " <> show x

  -- Count up to the last constructor rather than through all of 'Int'.
  enumFrom :: HalideTypeCode -> [HalideTypeCode]
  enumFrom start = enumFromTo start HalideTypeBfloat

  -- Stop at the last constructor when stepping upwards, and at the first one
  -- when stepping downwards.
  enumFromThen :: HalideTypeCode -> HalideTypeCode -> [HalideTypeCode]
  enumFromThen first second = enumFromThenTo first second limit
    where
      limit
        | fromEnum second >= fromEnum first = HalideTypeBfloat
        | otherwise = HalideTypeInt

-- | Haskell counterpart of @halide_type_t@: describes the element type of a
-- Halide value as a category, a bit width and a number of vector lanes.
data HalideType = HalideType
  { halideTypeCode :: !HalideTypeCode
  -- ^ Category of the element type (int, uint, float, handle or bfloat).
  , halideTypeBits :: {-# UNPACK #-} !Word8
  -- ^ Number of bits per lane, e.g. 32 for @float@.
  , halideTypeLanes :: {-# UNPACK #-} !Word16
  -- ^ Number of vector lanes; the generated 'IsHalideType' instances use 1.
  }
  deriving stock (Show, Read, Eq)

-- | Memory layout of a 'HalideType', 4 bytes in total:
--
--   * offset 0: the type code, one byte holding 'fromEnum' of the code;
--   * offset 1: the bit width ('Word8');
--   * offset 2: the lane count ('Word16').
instance Storable HalideType where
  sizeOf :: HalideType -> Int
  sizeOf = const 4

  alignment :: HalideType -> Int
  alignment = const 4

  peek :: Ptr HalideType -> IO HalideType
  peek ptr = do
    codeByte <- peekByteOff ptr 0 :: IO Word8
    bits <- peekByteOff ptr 1
    lanes <- peekByteOff ptr 2
    pure (HalideType (toEnum (fromIntegral codeByte)) bits lanes)

  poke :: Ptr HalideType -> HalideType -> IO ()
  poke ptr (HalideType code bits lanes) = do
    let codeByte = fromIntegral (fromEnum code) :: Word8
    pokeByteOff ptr 0 codeByte
    pokeByteOff ptr 1 bits
    pokeByteOff ptr 2 lanes

-- | Specifies that a type is supported by Halide.
--
-- Instances for the types listed in 'halideTypes' are generated by
-- 'defineIsHalideTypeInstances'.
class Storable a => IsHalideType a where
  -- | The Halide type descriptor for @a@. The generated instances ignore the
  -- proxy value; it only selects the instance.
  halideTypeFor :: proxy a -> HalideType

  -- | Convert a Haskell value into a newly constructed C++ @Halide::Expr@.
  toCxxExpr :: a -> IO (ForeignPtr CxxExpr)

-- | Helper function to coerce 'Float' to 'CFloat' and 'Double' to 'CDouble'
-- before passing them to inline-c quasiquotes. This is needed because inline-c
-- assumes that @float@ in C corresponds to 'CFloat' in Haskell.
--
-- Yields the expression @id@ when the given Haskell type already is the one
-- inline-c associates with the C type, and @coerce@ otherwise.
optionallyCast :: String -> TH.TypeQ -> TH.ExpQ
optionallyCast cType hsTypeQ = do
  actualType <- hsTypeQ
  inlineCType <- C.getHaskellType False cType
  if actualType /= inlineCType
    then [e|coerce|]
    else [e|id|]

-- | Template Haskell splice that defines an 'IsHalideType' instance for one
-- entry of 'halideTypes'.
--
-- The tuple holds the C type name, the Haskell type and the 'HalideTypeCode'
-- reported by 'halideTypeFor'. Inside the quoted C++ code, the inline-c
-- substitution named T turns its argument @v@ into the antiquote
-- @$(<C type> v)@, so the Haskell value is marshalled as that C type.
instanceIsHalideType :: (String, TH.TypeQ, HalideTypeCode) -> TH.DecsQ
instanceIsHalideType (cType, hsType, typeCode) =
  C.substitute [("T", antiquote)] instanceDecls
  where
    -- Builds the inline-c antiquote for a Haskell variable named var.
    antiquote var = concat ["$(", cType, " ", var, ")"]
    instanceDecls =
      [d|
        instance IsHalideType $hsType where
          halideTypeFor _ =
            HalideType typeCode (fromIntegral (8 * sizeOf (undefined :: $hsType))) 1
          toCxxExpr value =
            cxxConstruct $ \ptr ->
              [CU.exp| void { new ($(Halide::Expr* ptr)) Halide::Expr{@T(x)} } |]
            where
              -- Converts Float/Double to the type inline-c expects, if needed.
              x = $(optionallyCast cType hsType) value
        |]

-- | Derive 'IsHalideType' instances for every entry of 'halideTypes', in list
-- order.
defineIsHalideTypeInstances :: TH.DecsQ
defineIsHalideTypeInstances = fmap concat (traverse instanceIsHalideType halideTypes)

-- | Template Haskell splice that defines a 'CxxConstructible' instance for the
-- Haskell type that inline-c associates with the given C++ type name.
--
-- Objects are allocated with 'cxxConstructWithDeleter'. The registered deleter
-- explicitly calls the C++ destructor. The quoted code uses three inline-c
-- substitutions:
--
--   * T: the C++ type name as given;
--   * Deleter: the head of the deleter function, @deleter(<type>* p)@;
--   * Class: the part of the name after its last @::@, used to spell the
--     destructor.
instanceCxxConstructible :: String -> TH.DecsQ
instanceCxxConstructible cType =
  C.substitute
    [ ("T", const cType)
    , ("Deleter", const deleterHead)
    , ("Class", const unqualifiedName)
    ]
    [d|
      instance CxxConstructible $(C.getHaskellType False cType) where
        cxxSizeOf = fromIntegral [CU.pure| size_t { sizeof(@T()) } |]
        cxxConstruct = cxxConstructWithDeleter byteCount destroy
          where
            byteCount = fromIntegral [CU.pure| size_t { sizeof(@T()) } |]
            destroy = [C.funPtr| void @Deleter() { p->~@Class()(); } |]
      |]
  where
    deleterHead = concat ["deleter(", cType, "* p)"]
    -- A name without "::" is returned unchanged.
    -- NOTE(review): breakOnEnd splits at the last "::" anywhere in the string,
    -- so a type with "::" inside template arguments (e.g.
    -- "std::vector<Halide::Expr>") yields a wrong destructor name.
    unqualifiedName = T.unpack (snd (T.breakOnEnd "::" (T.pack cType)))

-- | Specifies that a given Haskell type can be used with @std::vector@.
--
-- E.g. if we have @HasCxxVector Int16@, then using @std::vector<int16_t>*@
-- in inline-c quotes will work. Instances are generated with
-- 'instanceHasCxxVector'.
class HasCxxVector a where
  -- | Allocate an empty vector with @new@. A @Just n@ argument reserves
  -- capacity for @n@ elements.
  newCxxVector :: Maybe Int -> IO (Ptr (CxxVector a))

  -- | Release a vector obtained from 'newCxxVector'.
  deleteCxxVector :: Ptr (CxxVector a) -> IO ()

  -- | Current number of elements.
  cxxVectorSize :: Ptr (CxxVector a) -> IO Int

  -- | Append a copy of the pointed-to element.
  cxxVectorPushBack :: Ptr (CxxVector a) -> Ptr a -> IO ()

  -- | Pointer to the vector's element storage.
  cxxVectorData :: Ptr (CxxVector a) -> IO (Ptr a)

  -- | Copy every element into a Haskell list.
  peekCxxVector :: Storable a => Ptr (CxxVector a) -> IO [a]

-- | Template Haskell splice that defines an instance of 'HasCxxVector' for a
-- given C type name.
--
-- The quoted C++ code uses two inline-c substitutions: T expands to the
-- element type name, and VEC(v) expands to an antiquote that passes the Haskell
-- variable @v@ as a @std::vector<T>*@.
--
-- NOTE(review): allocaArray and peekArray are not among this module's visible
-- imports. Presumably they are resolved at the splice site; confirm.
instanceHasCxxVector :: String -> TH.DecsQ
instanceHasCxxVector cType =
  C.substitute [("T", const cType), ("VEC", vectorAntiquote)] vectorInstance
  where
    vectorAntiquote var = concat ["$(std::vector<", cType, ">* ", var, ")"]
    vectorInstance =
      [d|
        instance HasCxxVector $(C.getHaskellType False cType) where
          newCxxVector maybeSize = do
            v <- [CU.exp| std::vector<@T()>* { new std::vector<@T()>() } |]
            -- Reserve capacity only when a size hint was supplied.
            mapM_
              ( \size ->
                  let n = fromIntegral size
                   in [CU.exp| void { @VEC(v)->reserve($(size_t n)) } |]
              )
              maybeSize
            pure v
          deleteCxxVector vec = [CU.exp| void { delete @VEC(vec) } |]
          cxxVectorSize vec = fmap fromIntegral [CU.exp| size_t { @VEC(vec)->size() } |]
          cxxVectorPushBack vec x = [CU.exp| void { @VEC(vec)->push_back(*$(@T()* x)) } |]
          cxxVectorData vec = [CU.exp| @T()* { @VEC(vec)->data() } |]
          peekCxxVector vec = do
            count <- cxxVectorSize vec
            -- Copy into a temporary Haskell buffer, then read it as a list.
            allocaArray count $ \out -> do
              [CU.block| void {
                auto const& vec = *@VEC(vec);
                auto* out = $(@T()* out);
                std::uninitialized_copy(std::begin(vec), std::end(vec), out);
              } |]
              peekArray count out
        |]

-- | List of all supported types. Each entry holds the C type name used in
-- inline-c quotes, the Haskell type, and the matching 'HalideTypeCode'.
--
-- Floating-point types appear twice: once as the Haskell-native type and once
-- as the "Foreign.C.Types" wrapper, both under the same C name.
halideTypes :: [(String, TH.TypeQ, HalideTypeCode)]
halideTypes = floating <> signedIntegers <> unsignedIntegers
  where
    floating =
      [ ("float", [t|Float|], HalideTypeFloat)
      , ("float", [t|CFloat|], HalideTypeFloat)
      , ("double", [t|Double|], HalideTypeFloat)
      , ("double", [t|CDouble|], HalideTypeFloat)
      ]
    signedIntegers =
      [ ("int8_t", [t|Int8|], HalideTypeInt)
      , ("int16_t", [t|Int16|], HalideTypeInt)
      , ("int32_t", [t|Int32|], HalideTypeInt)
      , ("int64_t", [t|Int64|], HalideTypeInt)
      ]
    unsignedIntegers =
      [ ("uint8_t", [t|Word8|], HalideTypeUInt)
      , ("uint16_t", [t|Word16|], HalideTypeUInt)
      , ("uint32_t", [t|Word32|], HalideTypeUInt)
      , ("uint64_t", [t|Word64|], HalideTypeUInt)
      ]

infixr 5 :::

-- | A heterogeneous list: the type index records the type of every element,
-- e.g. @(1 :: Int) ::: True ::: Nil :: Arguments '[Int, Bool]@.
--
-- Both fields of '(:::)' are strict.
data Arguments (k :: [Type]) where
  -- | The empty list.
  Nil :: Arguments '[]
  -- | Prepend one element (right-associative, precedence 5).
  (:::) :: !t -> !(Arguments ts) -> Arguments (t ': ts)

-- | Number of elements of a type-level list, as a type-level 'Nat'.
type family Length (xs :: [k]) :: Nat where
  Length '[] = 0
  Length (_ ': rest) = 1 + Length rest

-- | Add one element at the end of a type-level list.
type family Append (xs :: [k]) (y :: k) :: [k] where
  Append '[] y = '[y]
  Append (h ': rest) y = h ': Append rest y

-- | Concatenate two type-level lists, keeping the order of both.
type family Concat (xs :: [k]) (ys :: [k]) :: [k] where
  Concat '[] ys = ys
  Concat (h ': rest) ys = h ': Concat rest ys

-- | Append a value to the end of an 'Arguments' list.
--
-- Every existing element is re-consed, so the cost is linear in the length of
-- the list.
argumentsAppend :: Arguments xs -> t -> Arguments (Append xs t)
argumentsAppend Nil lastArg = lastArg ::: Nil
argumentsAppend (firstArg ::: restArgs) lastArg =
  firstArg ::: argumentsAppend restArgs lastArg

-- | The argument types of a curried function type, in order. Any type that is
-- not a function type yields the empty list. Like any closed type family, it
-- reduces only once that is known.
type family FunctionArguments (f :: Type) :: [Type] where
  FunctionArguments (arg -> rest) = arg ': FunctionArguments rest
  FunctionArguments nonFunction = '[]

-- | The final result type of a curried function type, i.e. what remains after
-- all arguments have been supplied.
type family FunctionReturn (f :: Type) :: Type where
  FunctionReturn (arg -> rest) = FunctionReturn rest
  FunctionReturn result = result

-- | Apply a constraint to every type in a list. For example, @All Show '[A, B]@
-- reduces to @(Show A, (Show B, ()))@.
--
-- The injectivity annotation (@p -> ts@) lets GHC recover the list from the
-- resulting constraint.
type family All (c :: Type -> Constraint) (ts :: [Type]) = (p :: Constraint) | p -> ts where
  All cls '[] = ()
  All cls (x ': rest) = (cls x, All cls rest)

-- | A helper typeclass to convert a normal curried function to a function that
-- takes 'Arguments' as input.
--
-- For instance, a function @f :: Int -> Float -> Double@ is turned into
-- @uncurryG f :: Arguments '[Int, Float] -> Double@.
--
-- The functional dependencies determine the function type from the argument
-- list and the result type, and the result type from the argument list and
-- the function type.
class UnCurry (f :: Type) (args :: [Type]) (r :: Type) | args r -> f, args f -> r where
  -- | Apply a curried function to all elements of an 'Arguments' list.
  uncurryG :: f -> Arguments args -> r

-- | Base case: with no arguments left, the value is already the result. The
-- equality constraints rule out @f@ being a function type.
instance (FunctionArguments f ~ '[], FunctionReturn f ~ r, f ~ r) => UnCurry f '[] r where
  uncurryG :: f -> Arguments '[] -> r
  uncurryG result Nil = result
  {-# INLINE uncurryG #-}

-- | Inductive case: feed the head of the list to the function, then recurse on
-- the tail.
instance (UnCurry f args r) => UnCurry (a -> f) (a ': args) r where
  uncurryG :: (a -> f) -> Arguments (a ': args) -> r
  uncurryG fn (arg ::: rest) = uncurryG (fn arg) rest
  {-# INLINE uncurryG #-}

-- | A helper typeclass to convert a function that takes 'Arguments' as input
-- into a normal curried function. This is the inverse of 'UnCurry'.
--
-- For instance, a function @f :: Arguments '[Int, Float] -> Double@ is turned
-- into @curryG f :: Int -> Float -> Double@.
class Curry (args :: [Type]) (r :: Type) (f :: Type) | args r -> f where
  -- | Collect arguments one at a time and pass them as 'Arguments'.
  curryG :: (Arguments args -> r) -> f

-- | Base case: there is nothing to collect, so apply the function to 'Nil'.
instance Curry '[] r r where
  curryG :: (Arguments '[] -> r) -> r
  curryG g = g Nil
  {-# INLINE curryG #-}

-- | Inductive case: accept one argument, then keep collecting the rest. The
-- accepted argument is consed onto the front of the final 'Arguments' list.
instance Curry args r f => Curry (a ': args) r (a -> f) where
  curryG :: (Arguments (a ': args) -> r) -> a -> f
  curryG g arg = curryG (g . (arg :::))