{-# LANGUAGE TypeFamilies #-}
{- |
Alternative to 'LLVM.Core.defineFunction'
that creates the final 'LLVM.Core.ret' instruction for you.
-}
module LLVM.Extra.Function (
   C,
   CodeGen,
   define,
   create,
   createNamed,
   Return, Result, ret,
   ) where

import qualified LLVM.Util.Proxy as LP
import qualified LLVM.Core as LLVM

import Foreign.StablePtr (StablePtr)
import Foreign.Ptr (Ptr, FunPtr)

import Control.Applicative ((<*>))

import Data.Int (Int8, Int16, Int32, Int64)
import Data.Word (Word8, Word16, Word32, Word64, Word)


{- |
Attach a body to an already declared function.

The body is given as a 'CodeGen' generator;
it is converted by 'addRet' before being passed to 'LLVM.defineFunction',
so the caller does not emit the terminating return itself.
-}
define ::
   (C f) => LLVM.Function f -> CodeGen f -> LLVM.CodeGenModule ()
define fn body =
   LLVM.defineFunction fn $ addRet signatureProxy body
   where
      -- The proxy carries the function type @f@ taken from the handle,
      -- which selects the matching 'C' instance for 'addRet'.
      signatureProxy = proxyFromElement2 fn

{- |
Produce a proxy for the innermost type argument @a@
of a value whose type has the shape @f (g a)@.
The value itself is never inspected.
-}
proxyFromElement2 :: f (g a) -> LP.Proxy a
proxyFromElement2 = const LP.Proxy


{- |
Declare a new anonymous function with the given linkage,
give it the supplied body via 'define'
and return the handle of the function.
-}
create ::
   (C f) =>
   LLVM.Linkage -> CodeGen f -> LLVM.CodeGenModule (LLVM.Function f)
create linkage body =
   LLVM.newFunction linkage >>= \fn ->
   define fn body >>
   return fn

{- |
Like 'create', but the function is declared under an explicit name
using 'LLVM.newNamedFunction'.
-}
createNamed ::
   (C f) =>
   LLVM.Linkage -> String -> CodeGen f -> LLVM.CodeGenModule (LLVM.Function f)
createNamed linkage name body =
   LLVM.newNamedFunction linkage name >>= \fn ->
   define fn body >>
   return fn


{- |
Function signatures for which 'define' can generate a body.

The associated type 'CodeGen' maps a function signature
to the type of the generator that the user has to supply:

> CodeGen (a -> b -> ... -> IO z) =
>    Value a -> Value b -> ... -> CodeGenFunction z (Result z)

For the 'Return' instances in this module
@Result z@ is @Value z@, except for @z = ()@ where it is @()@.
-}
class LLVM.FunctionArgs f => C f where
   -- | Type of the user-supplied code generator for the signature @f@.
   type CodeGen f
   -- | Convert a user-supplied generator into the form
   -- that 'LLVM.defineFunction' accepts.
   -- The proxy argument pins down @f@;
   -- @CodeGen f@ alone cannot do that,
   -- because a type family need not be injective.
   addRet :: LP.Proxy f -> CodeGen f -> LLVM.FunctionCodeGen f

-- Argument case: the generator receives one more 'LLVM.Value'
-- and the conversion continues on the remaining signature @b@.
instance (C b, LLVM.IsFirstClass a) => C (a -> b) where
   type CodeGen (a -> b) = LLVM.Value a -> CodeGen b
   addRet funcProxy body arg = addRet resultProxy (body arg)
      where
         -- Applying a proxy of @a -> b@ to a proxy of @a@
         -- yields a proxy of the remaining signature @b@.
         resultProxy = funcProxy <*> LP.Proxy

-- Base case: all arguments are consumed;
-- run the generator and pass its result to 'ret'.
instance Return a => C (IO a) where
   type CodeGen (IO a) = LLVM.CodeGenFunction a (Result a)
   addRet LP.Proxy body = do
      result <- body
      ret result


{- |
Types that can be the result type @a@ of a function defined via 'define'.
-}
class LLVM.IsFirstClass a => Return a where
   -- | What the code generator has to deliver
   -- for a function with result type @a@.
   type Result a
   -- | Emit the return of the given result.
   ret :: Result a -> LLVM.CodeGenFunction a ()
-- Functions with result type @()@:
-- the generator delivers a plain @()@ instead of an 'LLVM.Value'.
instance Return () where
   type Result () = ()
   ret unit = LLVM.ret unit

-- Boolean result: the generator delivers an 'LLVM.Value'.
instance Return Bool where
   type Result Bool = LLVM.Value Bool
   ret = LLVM.ret
-- Signed integer results: the generator delivers an 'LLVM.Value'.
instance Return Int where
   type Result Int = LLVM.Value Int
   ret = LLVM.ret

instance Return Int8 where
   type Result Int8 = LLVM.Value Int8
   ret = LLVM.ret

instance Return Int16 where
   type Result Int16 = LLVM.Value Int16
   ret = LLVM.ret

instance Return Int32 where
   type Result Int32 = LLVM.Value Int32
   ret = LLVM.ret

instance Return Int64 where
   type Result Int64 = LLVM.Value Int64
   ret = LLVM.ret
-- Unsigned integer results: the generator delivers an 'LLVM.Value'.
instance Return Word where
   type Result Word = LLVM.Value Word
   ret = LLVM.ret

instance Return Word8 where
   type Result Word8 = LLVM.Value Word8
   ret = LLVM.ret

instance Return Word16 where
   type Result Word16 = LLVM.Value Word16
   ret = LLVM.ret

instance Return Word32 where
   type Result Word32 = LLVM.Value Word32
   ret = LLVM.ret

instance Return Word64 where
   type Result Word64 = LLVM.Value Word64
   ret = LLVM.ret

-- Floating point results: the generator delivers an 'LLVM.Value'.
instance Return Float where
   type Result Float = LLVM.Value Float
   ret = LLVM.ret

instance Return Double where
   type Result Double = LLVM.Value Double
   ret = LLVM.ret

-- Pointer-like results: the generator delivers an 'LLVM.Value'.
-- Note that 'Ptr' (from Foreign.Ptr) and 'LLVM.Ptr' are distinct types.
instance Return (Ptr a) where
   type Result (Ptr a) = LLVM.Value (Ptr a)
   ret = LLVM.ret

instance (LLVM.IsType a) => Return (LLVM.Ptr a) where
   type Result (LLVM.Ptr a) = LLVM.Value (LLVM.Ptr a)
   ret = LLVM.ret

instance (LLVM.IsFunction a) => Return (FunPtr a) where
   type Result (FunPtr a) = LLVM.Value (FunPtr a)
   ret = LLVM.ret

instance Return (StablePtr a) where
   type Result (StablePtr a) = LLVM.Value (StablePtr a)
   ret = LLVM.ret