{-# LANGUAGE TypeFamilies #-}
{- |
Useful control structures additionally to those in "LLVM.Util.Loop".
-}
module LLVM.Extra.Control (
   arrayLoop,
   arrayLoop2,
   arrayLoopWithExit,
   arrayLoop2WithExit,
   fixedLengthLoop,
   whileLoop,
   whileLoopShared,
   loopWithExit,
   ifThenElse,
   ifThen,
   Select(select),
   selectTraversable,
   ifThenSelect,
   ) where

import LLVM.Extra.ArithmeticPrivate
   (cmp, sub, dec, advanceArrayElementPtr, )
import qualified LLVM.Extra.ArithmeticPrivate as A
import qualified LLVM.Core as LLVM
import LLVM.Util.Loop (Phi, phis, addPhis, )
import LLVM.Core
   (getCurrentBasicBlock, newBasicBlock, defineBasicBlock,
    br, condBr,
    Value, value, valueOf,
    phi, addPhiInputs,
    CmpPredicate(CmpGT), CmpRet, CmpResult,
    IsInteger, IsType, IsConst, IsFirstClass,
    CodeGenFunction,
    CodeGenModule, newModule, defineModule, writeBitcodeToFile, )

import Foreign.Ptr (Ptr, )

import qualified Control.Applicative as App
import qualified Data.Traversable as Trav
import Control.Monad (liftM3, liftM2, )

import Data.Tuple.HT (mapSnd, )



-- * control structures

{-
I had to export Phi's methods in llvm-0.6.8
in order to be able to implement this function.
-}
{- |
Walk over @len@ consecutive array elements beginning at @ptr@.
The body gets the pointer to the current element and the current state
and yields the next state.
The element pointer is advanced after every iteration;
the loop counter, initialized with @len@,
is decremented until it is no longer greater than zero.
-}
arrayLoop ::
   (Phi a, IsType b,
    Num i, IsConst i, IsInteger i, CmpRet i, CmpResult i ~ Bool) =>
   Value i -> Value (Ptr b) -> a ->
   (Value (Ptr b) -> a -> CodeGenFunction r a) ->
   CodeGenFunction r a
arrayLoop len ptr start loopBody =
   fmap snd (fixedLengthLoop len (ptr, start) step)
   where
      step (p, s) = do
         pNext <- advanceArrayElementPtr p
         sNext <- loopBody p s
         return (pNext, sNext)

{- |
Like 'arrayLoop', but walks two arrays in lock-step.
The second pointer is carried in the loop state alongside the user state
and advanced once per iteration.
-}
arrayLoop2 ::
   (Phi s, IsType a, IsType b,
    Num i, IsConst i, IsInteger i, CmpRet i, CmpResult i ~ Bool) =>
   Value i -> Value (Ptr a) -> Value (Ptr b) -> s ->
   (Value (Ptr a) -> Value (Ptr b) -> s -> CodeGenFunction r s) ->
   CodeGenFunction r s
arrayLoop2 len ptrA ptrB start loopBody =
   fmap snd (arrayLoop len ptrA (ptrB, start) step)
   where
      step pa (pb, s) = do
         pbNext <- advanceArrayElementPtr pb
         sNext <- loopBody pa pb s
         return (pbNext, sNext)


{- |
Walk over at most @len@ array elements beginning at @ptr@.
The body returns a flag together with the new state;
a false flag terminates the loop.
If the body requests an exit, counter and pointer are not advanced
for that element.
The result is @len@ minus the remaining counter,
that is, the position at which the loop stopped,
paired with the final state.
-}
arrayLoopWithExit ::
   (Phi s, IsType a,
    Num i, IsConst i, IsInteger i, CmpRet i, CmpResult i ~ Bool) =>
   Value i -> Value (Ptr a) -> s ->
   (Value (Ptr a) -> s -> CodeGenFunction r (Value Bool, s)) ->
   CodeGenFunction r (Value i, s)
arrayLoopWithExit len ptr start loopBody = do
   let step ((goOn, s0), (i, p)) =
          (do inRange <- cmp CmpGT i (value LLVM.zero)
              A.and goOn inRange,
           do outcome <- loopBody p s0
              advanced <-
                 ifThen (fst outcome) (i, p) $
                    liftM2 (,) (dec i) (advanceArrayElementPtr p)
              return (outcome, advanced))
   ((_, final), (remaining, _)) <-
      whileLoopShared ((valueOf True, start), (len, ptr)) step
   pos <- sub len remaining
   return (pos, final)


{- |
An alternative to 'arrayLoopWithExit'
where I try to persuade LLVM to use x86's LOOP instruction.
Unfortunately it becomes even worse.
LLVM developers say that x86 LOOP is actually slower
than manual decrement, zero test and conditional branch.

The result is @len@ minus the remaining counter and the loop state.
As in 'arrayLoopWithExit', when the body requests an exit
the returned state is the one that body invocation produced.
-}
_arrayLoopWithExitDecLoop ::
   (Phi a, IsType b,
    Num i, IsConst i, IsInteger i, CmpRet i, CmpResult i ~ Bool) =>
   Value i -> Value (Ptr b) -> a ->
   (Value (Ptr b) -> a -> CodeGenFunction r (Value Bool, a)) ->
   CodeGenFunction r (Value i, a)
_arrayLoopWithExitDecLoop len ptr start loopBody = do
   top <- getCurrentBasicBlock
   checkEnd <- newBasicBlock
   loop <- newBasicBlock
   next <- newBasicBlock
   exit <- newBasicBlock

   {- unfortunately, t0 is not just stored as processor flag
      but is written to a register and then tested again in checkEnd -}
   t0 <- cmp CmpGT len (value LLVM.zero)
   br checkEnd

   -- loop header: counter, element pointer, state and continuation flag
   defineBasicBlock checkEnd
   i <- phi [(len, top)]
   p <- phi [(ptr, top)]
   vars <- phis top start
   t <- phi [(t0, top)]
   condBr t loop exit

   defineBasicBlock loop

   (cont, vars') <- loopBody p vars
   {- loopBody may have started new blocks,
      so the branch below leaves from the current block,
      which is not necessarily 'loop' -}
   loopEnd <- getCurrentBasicBlock
   addPhis next vars vars'
   condBr cont next exit

   defineBasicBlock next
   p' <- advanceArrayElementPtr p
   i' <- dec i
   t' <- cmp CmpGT i' (value LLVM.zero)

   addPhiInputs i [(i', next)]
   addPhiInputs p [(p', next)]
   addPhiInputs t [(t', next)]
   br checkEnd

   {- exit is reached from checkEnd (counter exhausted),
      where the header state is current,
      or from the body (exit requested),
      where the state just computed by the body must be returned.
      Returning the header state 'vars' unconditionally
      would drop the last body result on early exit. -}
   defineBasicBlock exit
   varsExit <- phis checkEnd vars
   addPhis loopEnd varsExit vars'
   pos <- sub len i
   return (pos, varsExit)


{- |
Like 'arrayLoopWithExit', but walks two arrays in lock-step.
The second pointer is carried in the loop state and advanced
after each body invocation; only the user state is returned.
-}
arrayLoop2WithExit ::
   (Phi s, IsType a, IsType b,
    Num i, IsConst i, IsInteger i, CmpRet i, CmpResult i ~ Bool) =>
   Value i -> Value (Ptr a) -> Value (Ptr b) -> s ->
   (Value (Ptr a) -> Value (Ptr b) -> s -> CodeGenFunction r (Value Bool, s)) ->
   CodeGenFunction r (Value i, s)
arrayLoop2WithExit len ptrA ptrB start loopBody =
   fmap (mapSnd snd) (arrayLoopWithExit len ptrA (ptrB, start) step)
   where
      step pa (pb, s) = do
         (goOn, sNext) <- loopBody pa pb s
         pbNext <- advanceArrayElementPtr pb
         return (goOn, (pbNext, sNext))


{- |
Run the body repeatedly on the loop state
while a counter, initialized with @len@, is greater than zero.
The counter is decremented once per iteration.
-}
fixedLengthLoop ::
   (Phi s,
    Num i, IsConst i, IsInteger i, CmpRet i, CmpResult i ~ Bool) =>
   Value i -> s ->
   (s -> CodeGenFunction r s) ->
   CodeGenFunction r s
fixedLengthLoop len start loopBody =
   fmap snd (whileLoopShared (len, start) step)
   where
      step (remaining, s) =
         (cmp CmpGT remaining (value LLVM.zero),
          do remaining' <- dec remaining
             s' <- loopBody s
             return (remaining', s'))


{- |
Starting from @start@, evaluate @check@ on the current state
and run @body@ to obtain the next state as long as the check is true.
The state for which the check failed is returned.
'whileLoop' is expressed via 'loopWithExit';
'_whileLoop' spells out the same block structure directly.
-}
whileLoop, _whileLoop ::
   Phi a =>
   a ->
   (a -> CodeGenFunction r (Value Bool)) ->
   (a -> CodeGenFunction r a) ->
   CodeGenFunction r a
whileLoop start check body =
   loopWithExit start checkAndKeep body
   where
      checkAndKeep s = do
         goOn <- check s
         return (goOn, s)

_whileLoop start check body = do
   entryBlock <- getCurrentBasicBlock
   headerBlock <- newBasicBlock
   bodyBlock <- newBasicBlock
   exitBlock <- newBasicBlock
   br headerBlock

   -- header: merge incoming states and test the condition
   defineBasicBlock headerBlock
   current <- phis entryBlock start
   goOn <- check current
   condBr goOn bodyBlock exitBlock

   -- body: the body code may open further blocks,
   -- so the back edge starts from whatever block is current afterwards
   defineBasicBlock bodyBlock
   updated <- body current
   bodyEnd <- getCurrentBasicBlock
   addPhis bodyEnd current updated
   br headerBlock

   defineBasicBlock exitBlock
   return current


{- |
General loop with a check that also produces an intermediate value.
@check@ receives the current state and returns a continuation flag
and a value @b@.
If the flag is true, @body@ turns @b@ into the next state;
otherwise the loop ends and that last @b@ is returned.
-}
loopWithExit ::
   Phi a =>
   a ->
   (a -> CodeGenFunction r (Value Bool, b)) ->
   (b -> CodeGenFunction r a) ->
   CodeGenFunction r b
loopWithExit start check body = do
   entryBlock <- getCurrentBasicBlock
   headerBlock <- newBasicBlock
   bodyBlock <- newBasicBlock
   exitBlock <- newBasicBlock
   br headerBlock

   -- header: merge incoming states, then run the check
   defineBasicBlock headerBlock
   loopState <- phis entryBlock start
   (keepGoing, checked) <- check loopState
   condBr keepGoing bodyBlock exitBlock

   -- body: the body code may open further blocks,
   -- so the back edge starts from whatever block is current afterwards
   defineBasicBlock bodyBlock
   nextState <- body checked
   bodyEnd <- getCurrentBasicBlock
   addPhis bodyEnd loopState nextState
   br headerBlock

   defineBasicBlock exitBlock
   return checked


{- |
This is a variant of 'whileLoop' that may be more convenient,
because you only need one lambda expression
for both loop condition and loop body.
The function maps the current state to a pair
of condition code and body code.
-}
whileLoopShared ::
   Phi a =>
   a ->
   (a ->
      (CodeGenFunction r (Value Bool),
       CodeGenFunction r a)) ->
   CodeGenFunction r a
whileLoopShared start checkBody =
   whileLoop start
      (\s -> fst (checkBody s))
      (\s -> snd (checkBody s))

{- |
Run @thenCode@ if @cond@ holds and @elseCode@ otherwise,
merging both results with phi nodes.
This construct starts new blocks,
so be prepared when continuing after an 'ifThenElse'.
-}
ifThenElse ::
   Phi a =>
   Value Bool ->
   CodeGenFunction r a ->
   CodeGenFunction r a ->
   CodeGenFunction r a
ifThenElse cond thenCode elseCode = do
   thenEntry <- newBasicBlock
   elseEntry <- newBasicBlock
   joinBlock <- newBasicBlock
   condBr cond thenEntry elseEntry

   -- each branch may open further blocks,
   -- so record the block that is current when the branch ends
   defineBasicBlock thenEntry
   thenResult <- thenCode
   thenExit <- getCurrentBasicBlock
   br joinBlock

   defineBasicBlock elseEntry
   elseResult <- elseCode
   elseExit <- getCurrentBasicBlock
   br joinBlock

   defineBasicBlock joinBlock
   merged <- phis thenExit thenResult
   addPhis elseExit merged elseResult
   return merged


{- |
Run @thenCode@ only if @cond@ holds; otherwise yield @deflt@.
Both possible results are merged with phi nodes,
and the construct starts new blocks.
-}
ifThen ::
   Phi a =>
   Value Bool ->
   a ->
   CodeGenFunction r a ->
   CodeGenFunction r a
ifThen cond deflt thenCode = do
   entryBlock <- getCurrentBasicBlock
   thenEntry <- newBasicBlock
   joinBlock <- newBasicBlock
   condBr cond thenEntry joinBlock

   -- thenCode may open further blocks
   defineBasicBlock thenEntry
   thenResult <- thenCode
   thenExit <- getCurrentBasicBlock
   br joinBlock

   defineBasicBlock joinBlock
   merged <- phis entryBlock deflt
   addPhis thenExit merged thenResult
   return merged


{- |
Types whose values can be chosen by a boolean condition
via 'select', extending to tuples component-wise.
-}
class Phi a => Select a where
   select :: Value Bool -> a -> a -> CodeGenFunction r a

-- | Delegates to 'LLVM.select'.
instance (IsFirstClass a, CmpRet a, CmpResult a ~ Bool) => Select (Value a) where
   select = LLVM.select

instance Select () where
   select _ () () = return ()

instance (Select a, Select b) => Select (a,b) where
   select cond (a0,b0) (a1,b1) = do
      a <- select cond a0 a1
      b <- select cond b0 b1
      return (a, b)

instance (Select a, Select b, Select c) => Select (a,b,c) where
   select cond (a0,b0,c0) (a1,b1,c1) = do
      a <- select cond a0 a1
      b <- select cond b0 b1
      c <- select cond c0 c1
      return (a, b, c)

{- |
Combine @x@ and @y@ with 'App.liftA2' into a structure of 'select' actions
driven by @b@, then run those actions in traversal order.
-}
selectTraversable ::
   (Select a, Trav.Traversable f, App.Applicative f) =>
   Value Bool -> f a -> f a -> CodeGenFunction r (f a)
selectTraversable b x y =
   Trav.sequence selections
   where
      selections = App.liftA2 (select b) x y


{- |
Branch-free variant of 'ifThen'
that is faster if the enclosed block is very simple,
say, if it contains at most two instructions.
The enclosed code is always executed, regardless of the condition,
and its result is then chosen with 'select'.
Therefore it can only be used as an alternative to 'ifThen'
if the enclosed block is free of side effects.
-}
ifThenSelect ::
   Select a =>
   Value Bool ->
   a ->
   CodeGenFunction r a ->
   CodeGenFunction r a
ifThenSelect cond deflt thenCode =
   thenCode >>= \thenResult -> select cond thenResult deflt


-- * debugging

{- |
Debugging helper: create a new module, populate it with @cgm@,
discard the generator's result and write the module as bitcode
to @fileName@.
-}
_emitCode :: FilePath -> CodeGenModule a -> IO ()
_emitCode fileName cgm =
   newModule >>= \m ->
   defineModule m cgm >>
   writeBitcodeToFile fileName m