{-# LANGUAGE TypeFamilies #-}
{- |
Alternative to 'LLVM.Core.defineFunction'
that creates the final 'LLVM.Core.ret' instruction for you.
-}
module LLVM.Extra.Function (
   C,
   CodeGen,
   define,
   create,
   createNamed,
   Return, Result, ret,
   ) where

import qualified LLVM.Util.Proxy as LP
import qualified LLVM.Core as LLVM

import Foreign.StablePtr (StablePtr)
import Foreign.Ptr (Ptr, FunPtr)

import Control.Applicative ((<*>))

import Data.Int (Int8, Int16, Int32, Int64)
import Data.Word (Word8, Word16, Word32, Word64, Word)


-- | Fill in the body of an already declared function.
--
-- The body is a 'CodeGen' value: it receives the function's arguments as
-- 'LLVM.Value's and yields the function's result.  The final return
-- instruction is appended for you by 'addRet' (which ends in 'ret'),
-- instead of having to be written explicitly in the body.
define ::
   (C f) => LLVM.Function f -> CodeGen f -> LLVM.CodeGenModule ()
define fn body =
   -- 'proxyFromElement2' recovers the signature type @f@ from the function
   -- handle so that the right 'C' instance is selected for 'addRet'.
   LLVM.defineFunction fn (addRet (proxyFromElement2 fn) body)

-- | Proxy for the innermost argument of a doubly applied type constructor.
--
-- The argument value is never inspected; only its type is used.  'define'
-- applies this to an @LLVM.Function f@ to obtain an @LP.Proxy f@.
proxyFromElement2 :: f (g a) -> LP.Proxy a
proxyFromElement2 _ = LP.Proxy


-- | Declare a new function with the given linkage (no explicit name; see
-- 'createNamed' for the named variant), define its body with 'define',
-- and return the function handle.
create ::
   (C f) =>
   LLVM.Linkage -> CodeGen f -> LLVM.CodeGenModule (LLVM.Function f)
create linkage body = do
   f <- LLVM.newFunction linkage
   define f body
   return f

-- | Like 'create', but declares the function under the given name
-- (via 'LLVM.newNamedFunction') before defining its body with 'define'.
createNamed ::
   (C f) =>
   LLVM.Linkage -> String -> CodeGen f -> LLVM.CodeGenModule (LLVM.Function f)
createNamed linkage name body = do
   f <- LLVM.newNamedFunction linkage name
   define f body
   return f


{- |
Function signatures whose body can be generated with the final return
instruction added automatically.

> CodeGen (a -> b -> ... -> IO z) =
>    Value a -> Value b -> ... -> CodeGenFunction z (Result z)

For the result types provided here, @Result z@ is @Value z@,
except for @()@ where it is @()@.
-}
class LLVM.FunctionArgs f => C f where
   -- | Type of the user-supplied body for a function of signature @f@.
   type CodeGen f
   -- | Turn a body into an 'LLVM.FunctionCodeGen' that ends with 'ret'.
   -- The proxy only selects the instance; its value is not used.
   addRet :: LP.Proxy f -> CodeGen f -> LLVM.FunctionCodeGen f

-- | An argument position: the body receives the argument as an
-- 'LLVM.Value' and 'addRet' recurses on the remaining signature @b@.
instance (C b, LLVM.IsFirstClass a) => C (a -> b) where
   type CodeGen (a -> b) = LLVM.Value a -> CodeGen b
   -- @proxy <*> LP.Proxy@ turns an @LP.Proxy (a -> b)@ into @LP.Proxy b@.
   addRet proxy f a = addRet (proxy <*> LP.Proxy) (f a)

-- | The result position: run the body and hand its result to 'ret',
-- which emits the return instruction.
instance Return a => C (IO a) where
   type CodeGen (IO a) = LLVM.CodeGenFunction a (Result a)
   addRet LP.Proxy code = code >>= ret


-- | Types that may appear as the result of a generated function.
class LLVM.IsFirstClass a => Return a where
   -- | What the body must produce in order to return a value of type @a@.
   type Result a
   -- | Emit the return instruction for a function whose result type is @a@.
   ret :: Result a -> LLVM.CodeGenFunction a ()
-- A unit result carries no value; the body yields @()@, which is passed
-- straight to 'LLVM.ret'.
instance Return () where
   type Result () = ()
   ret = LLVM.ret

-- 'Bool' results are returned as an 'LLVM.Value'.
instance Return Bool where
   type Result Bool = LLVM.Value Bool
   ret = LLVM.ret
-- 'Int' results are returned as an 'LLVM.Value'.
instance Return Int where
   type Result Int = LLVM.Value Int
   ret = LLVM.ret
-- 'Int8' results are returned as an 'LLVM.Value'.
instance Return Int8 where
   type Result Int8 = LLVM.Value Int8
   ret = LLVM.ret
-- 'Int16' results are returned as an 'LLVM.Value'.
instance Return Int16 where
   type Result Int16 = LLVM.Value Int16
   ret = LLVM.ret
-- 'Int32' results are returned as an 'LLVM.Value'.
instance Return Int32 where
   type Result Int32 = LLVM.Value Int32
   ret = LLVM.ret
-- 'Int64' results are returned as an 'LLVM.Value'.
instance Return Int64 where
   type Result Int64 = LLVM.Value Int64
   ret = LLVM.ret
-- 'Word' results are returned as an 'LLVM.Value'.
instance Return Word where
   type Result Word = LLVM.Value Word
   ret = LLVM.ret
-- 'Word8' results are returned as an 'LLVM.Value'.
instance Return Word8 where
   type Result Word8 = LLVM.Value Word8
   ret = LLVM.ret
-- 'Word16' results are returned as an 'LLVM.Value'.
instance Return Word16 where
   type Result Word16 = LLVM.Value Word16
   ret = LLVM.ret
-- 'Word32' results are returned as an 'LLVM.Value'.
instance Return Word32 where
   type Result Word32 = LLVM.Value Word32
   ret = LLVM.ret
-- 'Word64' results are returned as an 'LLVM.Value'.
instance Return Word64 where
   type Result Word64 = LLVM.Value Word64
   ret = LLVM.ret

-- 'Float' results are returned as an 'LLVM.Value'.
instance Return Float where
   type Result Float = LLVM.Value Float
   ret = LLVM.ret
-- 'Double' results are returned as an 'LLVM.Value'.
instance Return Double where
   type Result Double = LLVM.Value Double
   ret = LLVM.ret

-- Host pointers ('Foreign.Ptr.Ptr') are returned as an 'LLVM.Value'.
instance Return (Ptr a) where
   type Result (Ptr a) = LLVM.Value (Ptr a)
   ret = LLVM.ret
-- Typed LLVM pointers ('LLVM.Ptr') are returned as an 'LLVM.Value'.
instance (LLVM.IsType a) => Return (LLVM.Ptr a) where
   type Result (LLVM.Ptr a) = LLVM.Value (LLVM.Ptr a)
   ret = LLVM.ret
-- Function pointers are returned as an 'LLVM.Value'.
instance (LLVM.IsFunction a) => Return (FunPtr a) where
   type Result (FunPtr a) = LLVM.Value (FunPtr a)
   ret = LLVM.ret
-- Stable pointers are returned as an 'LLVM.Value'.
instance Return (StablePtr a) where
   type Result (StablePtr a) = LLVM.Value (StablePtr a)
   ret = LLVM.ret