{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE FunctionalDependencies #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE UndecidableInstances #-}
{- |
Useful control structures in addition to those in "LLVM.Util.Loop".
-}
module LLVM.Extra.Control (
   arrayLoop,
   arrayLoopWithExit,
   arrayLoop2WithExit,
   whileLoop,
   ifThenElse,
   ifThen,
   Select(select),
   selectTraversable,
   ifThenSelect,
   ) where

import LLVM.Extra.Arithmetic
   (icmp, sub, dec, advanceArrayElementPtr, )
import qualified LLVM.Core as LLVM
import LLVM.Core
   (getCurrentBasicBlock, newBasicBlock, defineBasicBlock,
    br, condBr,
    Ptr, Value, value,
    phi, addPhiInputs,
    IntPredicate(IntNE), CmpRet,
    IsInteger, IsType, IsConst, IsFirstClass,
    CodeGenFunction,
    CodeGenModule, newModule, defineModule, writeBitcodeToFile, )
import LLVM.Util.Loop (Phi, phis, addPhis, )

import qualified Control.Applicative as App
import qualified Data.Traversable as Trav
import Control.Monad (liftM3, liftM2, )

import Data.Tuple.HT (mapSnd, )



-- * control structures

{-
I had to export Phi's methods in llvm-0.6.8
in order to be able to implement this function.
-}
{- |
Emit a counted loop over an array.

The loop header compares the remaining count with zero.
While it differs from zero, @loopBody@ is run
with the current element pointer and the current loop state;
afterwards the count is updated with 'dec'
and the pointer with 'advanceArrayElementPtr'.

The result is the loop state as seen in the loop header,
that is, the state after the last completed iteration
(or @start@ if the body was never entered).
-}
arrayLoop ::
   (Phi a, IsType b,
    Num i, IsConst i, IsInteger i, IsFirstClass i, CmpRet i Bool) =>
   Value i -> Value (Ptr b) -> a ->
   (Value (Ptr b) -> a -> CodeGenFunction r a) ->
   CodeGenFunction r a
arrayLoop len ptr start loopBody = do
   entryBlock <- getCurrentBasicBlock
   headBlock <- newBasicBlock
   bodyBlock <- newBasicBlock
   doneBlock <- newBasicBlock

   br headBlock

   -- loop header: phi nodes for the counter, the pointer and the user state
   defineBasicBlock headBlock
   remaining <- phi [(len, entryBlock)]
   cursor <- phi [(ptr, entryBlock)]
   state <- phis entryBlock start
   more <- icmp IntNE remaining (value LLVM.zero)
   condBr more bodyBlock doneBlock

   defineBasicBlock bodyBlock

   state' <- loopBody cursor state
   remaining' <- dec remaining
   cursor' <- advanceArrayElementPtr cursor

   -- loopBody may have started further blocks,
   -- thus the back edge originates from whatever block is current now
   latchBlock <- getCurrentBasicBlock
   addPhis latchBlock state state'
   addPhiInputs remaining [(remaining', latchBlock)]
   addPhiInputs cursor [(cursor', latchBlock)]
   br headBlock

   defineBasicBlock doneBlock
   return state


{- |
Counted loop over an array that can be left early.

In every iteration @loopBody@ receives the current element pointer
and the current loop state and returns a continuation flag
together with the new state.
The loop ends when the remaining count reaches zero
or when the flag tells not to continue.

The first component of the result is @len@ minus the remaining count
at the time the loop was left.
The second component is the loop state as seen in the loop header:
The state returned by an iteration that requests the exit
is not passed on to the caller.
-}
arrayLoopWithExit ::
   (Phi s, IsType a,
    Num i, IsConst i, IsInteger i, IsFirstClass i, CmpRet i Bool) =>
   Value i -> Value (Ptr a) -> s ->
   (Value (Ptr a) -> s -> CodeGenFunction r (Value Bool, s)) ->
   CodeGenFunction r (Value i, s)
arrayLoopWithExit len ptr start loopBody = do
   entryBlock <- getCurrentBasicBlock
   headBlock <- newBasicBlock
   bodyBlock <- newBasicBlock
   stepBlock <- newBasicBlock
   doneBlock <- newBasicBlock

   br headBlock

   -- loop header: phi nodes for the counter, the pointer and the user state
   defineBasicBlock headBlock
   remaining <- phi [(len, entryBlock)]
   cursor <- phi [(ptr, entryBlock)]
   state <- phis entryBlock start
   more <- icmp IntNE remaining (value LLVM.zero)
   condBr more bodyBlock doneBlock

   -- user code decides whether to go on;
   -- the new state reaches the header only via stepBlock
   defineBasicBlock bodyBlock
   (continue, state') <- loopBody cursor state
   addPhis stepBlock state state'
   condBr continue stepBlock doneBlock

   -- stepBlock is the only source of the back edge
   defineBasicBlock stepBlock
   remaining' <- dec remaining
   cursor' <- advanceArrayElementPtr cursor

   addPhiInputs remaining [(remaining', stepBlock)]
   addPhiInputs cursor [(cursor', stepBlock)]
   br headBlock

   defineBasicBlock doneBlock
   processed <- sub len remaining
   return (processed, state)


{- |
Variant of 'arrayLoopWithExit'
that was written in the hope that LLVM would emit x86's LOOP instruction.
The generated code turned out to be even worse.
According to the LLVM developers x86 LOOP is actually slower
than a manual decrement followed by a zero test and a conditional branch.

In contrast to 'arrayLoopWithExit' the zero test is not performed
in the loop header.
Instead it is computed once before the loop and once after every decrement
and the header only branches on the resulting flag.
-}
_arrayLoopWithExitDecLoop ::
   (Phi a, IsType b,
    Num i, IsConst i, IsInteger i, IsFirstClass i, CmpRet i Bool) =>
   Value i -> Value (Ptr b) -> a ->
   (Value (Ptr b) -> a -> CodeGenFunction r (Value Bool, a)) ->
   CodeGenFunction r (Value i, a)
_arrayLoopWithExitDecLoop len ptr start loopBody = do
   entryBlock <- getCurrentBasicBlock
   headBlock <- newBasicBlock
   bodyBlock <- newBasicBlock
   stepBlock <- newBasicBlock
   doneBlock <- newBasicBlock

   {- Sadly the initial flag does not stay in the processor flags.
      It is written to a register and tested again in the header block. -}
   more0 <- icmp IntNE len (value LLVM.zero)
   br headBlock

   -- header: besides counter, pointer and state
   -- the continuation flag is a phi node, too
   defineBasicBlock headBlock
   remaining <- phi [(len, entryBlock)]
   cursor <- phi [(ptr, entryBlock)]
   state <- phis entryBlock start
   more <- phi [(more0, entryBlock)]
   condBr more bodyBlock doneBlock

   defineBasicBlock bodyBlock

   (continue, state') <- loopBody cursor state
   addPhis stepBlock state state'
   condBr continue stepBlock doneBlock

   -- advance, decrement and test for zero right after the decrement
   defineBasicBlock stepBlock
   cursor' <- advanceArrayElementPtr cursor
   remaining' <- dec remaining
   more' <- icmp IntNE remaining' (value LLVM.zero)

   addPhiInputs remaining [(remaining', stepBlock)]
   addPhiInputs cursor [(cursor', stepBlock)]
   addPhiInputs more [(more', stepBlock)]
   br headBlock

   defineBasicBlock doneBlock
   processed <- sub len remaining
   return (processed, state)


{- |
Like 'arrayLoopWithExit' but traverses two arrays simultaneously.

It is implemented by means of 'arrayLoopWithExit' on the first array.
The pointer into the second array is carried along in the loop state
and is advanced with 'advanceArrayElementPtr' after each call to @loopBody@.
That pointer is removed from the result again.
-}
arrayLoop2WithExit ::
   (Phi s, IsType a, IsType b,
    Num i, IsConst i, IsInteger i, IsFirstClass i, CmpRet i Bool) =>
   Value i -> Value (Ptr a) -> Value (Ptr b) -> s ->
   (Value (Ptr a) -> Value (Ptr b) -> s -> CodeGenFunction r (Value Bool, s)) ->
   CodeGenFunction r (Value i, s)
arrayLoop2WithExit len ptrA ptrB start loopBody =
   fmap (mapSnd snd) (arrayLoopWithExit len ptrA (ptrB, start) step)
   where
      -- one iteration: run the user code, then move on in the second array
      step curA (curB, state) = do
         (continue, state') <- loopBody curA curB state
         nextB <- advanceArrayElementPtr curB
         return (continue, (nextB, state'))


{- |
Emit a loop with the test at the top.

Starting with @start@ the loop state is passed to @check@.
As long as @check@ yields a true flag,
@body@ computes the next state from the current one.
The result is the state for which the loop was left,
that is, the state as seen in the loop header.
-}
whileLoop ::
   Phi a =>
   a ->
   (a -> CodeGenFunction r (Value Bool)) ->
   (a -> CodeGenFunction r a) ->
   CodeGenFunction r a
whileLoop start check body = do
   entryBlock <- getCurrentBasicBlock
   headBlock <- newBasicBlock
   stepBlock <- newBasicBlock
   doneBlock <- newBasicBlock
   br headBlock

   -- loop header: phi nodes for the state, followed by the user's test
   defineBasicBlock headBlock
   current <- phis entryBlock start
   proceed <- check current
   condBr proceed stepBlock doneBlock

   defineBasicBlock stepBlock
   updated <- body current
   -- body may have started further blocks,
   -- thus ask for the block that actually jumps back
   latchBlock <- getCurrentBasicBlock
   addPhis latchBlock current updated
   br headBlock

   defineBasicBlock doneBlock
   return current


{- |
Two-way conditional:
Depending on @cond@ either @thenCode@ or @elseCode@ is executed
and the results of both branches are merged using phi nodes.

This construct starts new blocks,
so be prepared when continuing after an 'ifThenElse':
The current basic block after it is a fresh merge block.
-}
ifThenElse ::
   Phi a =>
   Value Bool ->
   CodeGenFunction r a ->
   CodeGenFunction r a ->
   CodeGenFunction r a
ifThenElse cond thenCode elseCode = do
   onTrue <- newBasicBlock
   onFalse <- newBasicBlock
   joinBlock <- newBasicBlock
   condBr cond onTrue onFalse

   -- each branch may start blocks of its own,
   -- thus remember the block from which it actually jumps to the merge
   defineBasicBlock onTrue
   thenResult <- thenCode
   thenExit <- getCurrentBasicBlock
   br joinBlock

   defineBasicBlock onFalse
   elseResult <- elseCode
   elseExit <- getCurrentBasicBlock
   br joinBlock

   defineBasicBlock joinBlock
   merged <- phis thenExit thenResult
   addPhis elseExit merged elseResult
   return merged


{- |
One-way conditional with a default result.

If @cond@ is true then @thenCode@ is executed and its result is returned,
otherwise @thenCode@ is skipped and @deflt@ is returned.
Both cases are merged using phi nodes in a fresh block,
which is the current basic block afterwards.
-}
ifThen ::
   Phi a =>
   Value Bool ->
   a ->
   CodeGenFunction r a ->
   CodeGenFunction r a
ifThen cond deflt thenCode = do
   -- the default value enters the merge directly from the current block
   entryBlock <- getCurrentBasicBlock
   onTrue <- newBasicBlock
   joinBlock <- newBasicBlock
   condBr cond onTrue joinBlock

   defineBasicBlock onTrue
   thenResult <- thenCode
   -- thenCode may have started further blocks
   thenExit <- getCurrentBasicBlock
   br joinBlock

   defineBasicBlock joinBlock
   merged <- phis entryBlock deflt
   addPhis thenExit merged thenResult
   return merged


{- |
Types whose values can be chosen between
by means of a 'Value' 'Bool' condition within a 'CodeGenFunction'.
Every such type must also support phi nodes (superclass 'Phi').
-}
class Phi a => Select a where
   -- | @select cond x y@ chooses between @x@ and @y@ depending on @cond@.
   -- For a single 'Value' this is LLVM's select instruction.
   -- 'ifThenSelect' passes the conditional result as @x@
   -- and the default as @y@, so presumably @x@ is chosen
   -- when @cond@ is true.
   select :: Value Bool -> a -> a -> CodeGenFunction r a

-- A single LLVM value is handled directly by the select instruction.
instance (IsFirstClass a, CmpRet a Bool) => Select (Value a) where
   select cond x0 x1 = LLVM.select cond x0 x1

-- The unit type carries no data, so no instruction needs to be emitted.
instance Select () where
   select _ () () = return ()

-- Pairs are selected component-wise with the same condition,
-- first component first.
instance (Select a, Select b) => Select (a,b) where
   select cond (x0, y0) (x1, y1) =
      liftM2 (,) (select cond x0 x1) (select cond y0 y1)

-- Triples are selected component-wise with the same condition,
-- from the first to the last component.
instance (Select a, Select b, Select c) => Select (a,b,c) where
   select cond (x0, y0, z0) (x1, y1, z1) =
      liftM3 (,,)
         (select cond x0 x1) (select cond y0 y1) (select cond z0 z1)

{- |
Apply 'select' with one common condition
to the elements of two containers.
The elements are paired using 'App.liftA2' of the container type
and the resulting code generation actions are run using 'Trav.sequence'.
-}
selectTraversable ::
   (Select a, Trav.Traversable f, App.Applicative f) =>
   Value Bool -> f a -> f a -> CodeGenFunction r (f a)
selectTraversable b x y =
   let pending = App.liftA2 (select b) x y
   in  Trav.sequence pending


{- |
Branch-free variant of 'ifThen'
that is faster if the enclosed block is very simple,
say, if it contains at most two instructions.

Here @thenCode@ is always executed
and only its result is chosen or dropped using 'select'.
Thus it can only be used as alternative to 'ifThen'
if the enclosed block is free of side effects.
-}
ifThenSelect ::
   Select a =>
   Value Bool ->
   a ->
   CodeGenFunction r a ->
   CodeGenFunction r a
ifThenSelect cond deflt thenCode =
   thenCode >>= \thenResult -> select cond thenResult deflt


-- * debugging

{- |
Debugging aid:
Create a new module, define the given code in it
and write the module as bitcode to the file @fileName@.
The result of the module definition is discarded.
-}
_emitCode :: FilePath -> CodeGenModule a -> IO ()
_emitCode fileName cgm =
   newModule >>= \m ->
      defineModule m cgm >> writeBitcodeToFile fileName m