{-# LANGUAGE ExistentialQuantification #-}
{-# LANGUAGE Rank2Types #-}
{-# LANGUAGE TypeFamilies #-}
module LLVM.Extra.Iterator (
   T,
   -- * consumers
   mapM_,
   mapState_,
   mapStateM_,
   mapWhileState_,
   -- * producers
   empty,
   singleton,
   cons,
   iterate,
   countDown,
   arrayPtrs,
   storableArrayPtrs,
   -- * modifiers
   mapM,
   mapMaybe,
   catMaybes,
   takeWhileJust,
   takeWhile,
   cartesian,
   take,
   -- * application examples
   fixedLengthLoop,
   arrayLoop,
   arrayLoopWithExit,
   arrayLoop2,
   ) where

import qualified LLVM.Extra.MaybeContinuation as MaybeCont
import qualified LLVM.Extra.Maybe as Maybe

import qualified LLVM.Extra.Storable as Storable
import qualified LLVM.Extra.ArithmeticPrivate as A
import qualified LLVM.Extra.Tuple as Tuple
import qualified LLVM.Extra.Control as C
import qualified LLVM.Core as LLVM
import LLVM.Core
   (CodeGenFunction, Value, value, valueOf,
    CmpRet, IsInteger, IsType, IsConst, IsPrimitive)

import Foreign.Ptr (Ptr, )

import qualified Control.Monad.Trans.State as MS
import qualified Control.Applicative as App
import qualified Control.Functor.HT as FuncHT
import Control.Monad (void, (<=<), )
import Control.Applicative (Applicative, liftA2, (<$>), (<$), )

import Data.Tuple.HT (mapFst, mapSnd, )

import Prelude2010 hiding (iterate, takeWhile, take, mapM, mapM_)
import Prelude ()


{- |
Simulates a non-strict list.

An iterator consists of a start state of some hidden type @s@
and a step function.
Applied to the current state, the step function either delivers
an element together with the successor state,
or signals the end of the sequence through 'MaybeCont.T'.
Consumers such as 'mapM_' tell the two cases apart with 'MaybeCont.resolve'.

The step function is polymorphic in @z@,
so that every consumer can choose the result type it resolves to.
The state type carries 'Tuple.Phi' and 'Tuple.Undefined' constraints;
consumers thread the state through 'C.loopWithExit'.
-}
data T r a =
   forall s. (Tuple.Phi s, Tuple.Undefined s) =>
   Cons s (forall z. (Tuple.Phi z) => s -> MaybeCont.T r z (a,s))

{- |
Run an action for every element of the iterator, in order,
and discard the final iterator state.
-}
mapM_ :: (a -> CodeGenFunction r ()) -> T r a -> CodeGenFunction r ()
mapM_ f (Cons s next) =
   void (C.loopWithExit s step return)
   where
      -- Stop when the iterator is exhausted,
      -- otherwise process the element and go on with the successor state.
      step sCur =
         MaybeCont.resolve (next sCur)
            (return (valueOf False, sCur))
            (\(a, sNext) -> do
               f a
               return (valueOf True, sNext))

{- |
Fold over all elements of the iterator, threading an accumulator @t@.
Returns the accumulator after the last element.
-}
mapState_ ::
   (Tuple.Phi t) =>
   (a -> t -> CodeGenFunction r t) ->
   T r a -> t -> CodeGenFunction r t
mapState_ f (Cons s next) t = do
   final <- C.loopWithExit (s, t) step return
   return (snd final)
   where
      -- Loop state: (iterator state, accumulator).
      -- Exhaustion stops the loop with the state unchanged.
      step (sCur, acc) =
         MaybeCont.resolve (next sCur)
            (return (valueOf False, (sCur, acc)))
            (\(a, sNext) -> do
               acc1 <- f a acc
               return (valueOf True, (sNext, acc1)))

{- |
Like 'mapState_', but the accumulator is managed by a 'MS.StateT' transformer.
-}
mapStateM_ ::
   (Tuple.Phi t) =>
   (a -> MS.StateT t (CodeGenFunction r) ()) ->
   T r a -> MS.StateT t (CodeGenFunction r) ()
mapStateM_ f xs =
   MS.StateT $ \t0 -> do
      t1 <- mapState_ (\a -> MS.execStateT (f a)) xs t0
      return ((), t1)


{- |
Like 'mapState_', but the callback additionally returns a flag:
iteration continues with the next element only while it is true.
Returns the accumulator at the point where iteration stopped.
-}
mapWhileState_ ::
   (Tuple.Phi t) =>
   (a -> t -> CodeGenFunction r (Value Bool, t)) ->
   T r a -> t -> CodeGenFunction r t
mapWhileState_ f (Cons s next) t =
   fmap snd (C.loopWithExit (s, t) step return)
   where
      -- Loop state: (iterator state, accumulator).
      step (sCur, acc) =
         MaybeCont.resolve (next sCur)
            (return (valueOf False, (sCur, acc)))
            (\(a, sNext) -> do
               (goOn, acc1) <- f a acc
               return (goOn, (sNext, acc1)))


{- |
An iterator without any element.
-}
empty :: T r a
empty = Cons () exhausted
   where
      exhausted () = MaybeCont.nothing

{- |
An iterator with exactly one element.
The state is a flag that is true as long as the element is still pending.
-}
singleton :: a -> T r a
singleton a =
   Cons (valueOf True)
      (\pending -> do
         MaybeCont.guard pending
         return (a, valueOf False))

{- |
Prepend an element to an iterator.
The state is 'Maybe.nothing' until the head has been delivered,
and afterwards 'Maybe.just' the state of the tail iterator.
-}
cons :: (Tuple.Phi a, Tuple.Undefined a) => a -> T r a -> T r a
cons a0 (Cons s next) =
   Cons Maybe.nothing
      (\mTail ->
         fmap (mapSnd Maybe.just) $
         MaybeCont.fromMaybe $
         Maybe.run mTail
            (return (Maybe.just (a0, s)))
            (\sTail -> MaybeCont.toMaybe (next sTail)))


-- | Apply a pure function to every element; the iterator state is untouched.
instance Functor (T r) where
   fmap f (Cons s0 step) =
      Cons s0
         (\s -> do
            ~(x, s1) <- step s
            return (f x, s1))

{- |
@ZipList@ semantics:
'pure' repeats its value forever,
and '<*>' combines elements position by position.
In every step the left iterator is stepped before the right one,
and the combined iterator ends as soon as either of them ends.
-}
instance Applicative (T r) where
   pure x = Cons () (\() -> return (x, ()))
   Cons sf stepF <*> Cons sx stepX =
      Cons (sf, sx)
         (\(sf0, sx0) ->
            stepF sf0 >>= \(g, sf1) ->
            stepX sx0 >>= \(x, sx1) ->
            return (g x, (sf1, sx1)))


{-
On the one hand,
I did not want to name it @map@ because it differs from @fmap@.
On the other hand, @mapM@ does not fit very well
because the result is not in the CodeGenFunction monad.
-}
{- |
Transform every element with a code-generating action.
The action is run when the element is produced.
-}
mapM :: (a -> CodeGenFunction r b) -> T r a -> T r b
mapM f (Cons s next) =
   Cons s
      (\s0 -> do
         ~(a, s1) <- next s0
         MaybeCont.lift (flip (,) s1 <$> f a))

{- |
Transform every element with an action that may reject it;
rejected elements are skipped.
-}
mapMaybe ::
   (Tuple.Phi b, Tuple.Undefined b) =>
   (a -> CodeGenFunction r (Maybe.T b)) -> T r a -> T r b
mapMaybe f xs = catMaybes (mapM f xs)

{- |
Skip the 'Maybe.nothing' elements and unwrap the 'Maybe.just' ones.

Each step of the result runs a loop over the source iterator
until that yields a 'Maybe.just' element or is exhausted.
-}
catMaybes :: (Tuple.Phi a, Tuple.Undefined a) => T r (Maybe.T a) -> T r a
catMaybes (Cons s next) =
   Cons s (\s0 -> MaybeCont.fromMaybe (seekJust s0))
   where
      -- Loop state: the source iterator state.
      -- Loop result: the element found (if any) and the source state after it.
      seekJust sStart = do
         ~(found, sEnd) <-
            C.loopWithExit sStart
               (\sCur ->
                  MaybeCont.resolve (next sCur)
                     -- source exhausted: stop without an element
                     (return (valueOf False, (Maybe.nothing, sCur)))
                     (\(ma, sNext) ->
                        Maybe.run ma
                           -- skip a nothing element and keep looping
                           (return (valueOf True, (Maybe.nothing, sNext)))
                           -- found an element: stop and deliver it
                           (\a -> return (valueOf False, (Maybe.just a, sNext)))))
               (return . snd)
         return (fmap (\a -> (a, sEnd)) found)

{- |
Deliver the unwrapped 'Maybe.just' elements
and end the iteration at the first 'Maybe.nothing' element.
-}
takeWhileJust :: T r (Maybe.T a) -> T r a
takeWhileJust (Cons s next) =
   Cons s
      (\s0 -> do
         ~(ma, s1) <- next s0
         flip (,) s1 <$> MaybeCont.fromPlainMaybe ma)

{- |
Deliver elements as long as the predicate holds;
end at the first element for which it is false.
-}
takeWhile :: (a -> CodeGenFunction r (Value Bool)) -> T r a -> T r a
takeWhile p = takeWhileJust . mapM keepIf
   where
      keepIf a = do
         ok <- p a
         return (Maybe.fromBool ok a)

{- |
The infinite sequence @a, f a, f (f a), ...@.

Attention:
Every step computes the successor of the element it delivers,
so this always performs one call of @f@ more than necessary.
I.e. if @f@ reads from or writes to memory,
make sure that accessing one more pointer is legal.
-}
iterate ::
   (Tuple.Phi a, Tuple.Undefined a) => (a -> CodeGenFunction r a) -> a -> T r a
iterate f a0 =
   Cons a0
      (\cur -> MaybeCont.lift $ do
         nxt <- f cur
         return (cur, nxt))


{- |
Helper for 'cartesian'.

The state is @(current outer element, outer state, inner state)@.
Each step pairs the current outer element with the next inner element.
When the inner iterator is exhausted, the step delivers 'Maybe.nothing',
forgets the current outer element
and resets the inner iterator to its initial state @sb@,
so that the following step continues with the next outer element.
'cartesian' removes the 'Maybe.nothing' placeholders via 'catMaybes'.
-}
cartesianAux ::
   (Tuple.Phi a, Tuple.Phi b, Tuple.Undefined a, Tuple.Undefined b) =>
   T r a -> T r b -> T r (Maybe.T (a,b))
cartesianAux (Cons sa nextA) (Cons sb nextB) =
   -- Start without a current outer element.
   Cons (Maybe.nothing,sa,sb)
      (\(ma0,sa0,sb0) -> do
         -- Reuse the current outer element if there is one,
         -- otherwise fetch the next one from the outer iterator.
         -- (Assumes 'MaybeCont.alternative' tries its second argument
         -- only if the first one yields nothing.)
         (a1,sa1) <-
            MaybeCont.alternative
               (MaybeCont.fromMaybe $ return $ fmap (flip (,) sa0) ma0)
               (nextA sa0)
         -- Pair it with the next inner element,
         -- or, if the inner iterator is exhausted,
         -- emit a placeholder and restart the inner iterator from @sb@.
         MaybeCont.lift $
            MaybeCont.resolve (nextB sb0)
               (return (Maybe.nothing,(Maybe.nothing,sa1,sb)))
               (\(b1,sb1) ->
                  return (Maybe.just (a1,b1), (Maybe.just a1, sa1, sb1))))

{- |
All pairs of elements of the two iterators.
The first iterator is the outer one:
the second iterator is restarted from its initial state
for every element of the first.
-}
cartesian ::
   (Tuple.Phi a, Tuple.Phi b, Tuple.Undefined a, Tuple.Undefined b) =>
   T r a -> T r b -> T r (a,b)
cartesian xs ys = catMaybes (cartesianAux xs ys)

{- |
The sequence @len, len-1, ..., 1@.
Empty if @len@ is not positive.
-}
countDown ::
   (Num i, IsConst i, IsInteger i, CmpRet i, IsPrimitive i) =>
   Value i -> T r (Value i)
countDown len =
   takeWhile (\x -> A.cmp LLVM.CmpLT (value LLVM.zero) x) $
   iterate A.dec len

{- |
Deliver at most @len@ elements of an iterator.

The counter is stepped before @xs@ in every round.
Thus once @len@ elements have been delivered,
the step function of @xs@ is not run again.
This matters when producing an element has an effect:
e.g. if @xs@ loads from consecutive array positions,
stepping @xs@ first would load from one position past the end.
-}
take ::
   (Num i, IsConst i, IsInteger i, CmpRet i, IsPrimitive i) =>
   Value i -> T r a -> T r a
take len xs = liftA2 (\_remaining x -> x) (countDown len) xs

{- |
Infinite sequence of pointers to consecutive array elements,
starting at the given pointer.
Limit it with e.g. 'take'.
-}
arrayPtrs :: (IsType a) => Value (LLVM.Ptr a) -> T r (Value (LLVM.Ptr a))
arrayPtrs ptr = iterate A.advanceArrayElementPtr ptr

{- |
Like 'arrayPtrs', but advances the pointer with 'Storable.incrementPtr'.
-}
storableArrayPtrs :: (Storable.C a) => Value (Ptr a) -> T r (Value (Ptr a))
storableArrayPtrs ptr = iterate Storable.incrementPtr ptr


-- * examples

{- |
Apply @loopBody@ @len@ times to the state, starting at @start@.
The body is not run at all if @len@ is not positive.
-}
fixedLengthLoop ::
   (Tuple.Phi s, Num i, IsConst i, IsInteger i, CmpRet i, IsPrimitive i) =>
   Value i -> s ->
   (s -> CodeGenFunction r s) ->
   CodeGenFunction r s
fixedLengthLoop len start loopBody =
   mapState_ (\_counter s -> loopBody s) (countDown len) start

{- |
Call @loopBody@ with pointers to the first @len@ elements of the array
starting at @ptr@, threading a state that starts at @start@.
-}
arrayLoop ::
   (Tuple.Phi a, IsType b, Num i, IsConst i, IsInteger i, CmpRet i, IsPrimitive i) =>
   Value i -> Value (LLVM.Ptr b) -> a ->
   (Value (LLVM.Ptr b) -> a -> CodeGenFunction r a) ->
   CodeGenFunction r a
arrayLoop len ptr start loopBody =
   mapState_ loopBody elementPtrs start
   where
      elementPtrs = take len (arrayPtrs ptr)

{- |
Loop over the elements of an array with the option of an early exit.

@loopBody@ is called with pointers to consecutive array elements,
starting at @ptr0@, for at most @len@ elements.
The @Value Bool@ it returns tells whether to continue with the next element.

The result is the position where the loop stopped, paired with the final state:

* If @loopBody@ returns false for the element at index @k@,
  the position is @k@.

* If the loop runs through all elements, the position is @len@
  (@0@ if @len@ is not positive).
-}
arrayLoopWithExit ::
   (Tuple.Phi s, IsType a, Num i, IsConst i, IsInteger i, CmpRet i, IsPrimitive i) =>
   Value i -> Value (LLVM.Ptr a) -> s ->
   (Value (LLVM.Ptr a) -> s -> CodeGenFunction r (Value Bool, s)) ->
   CodeGenFunction r (Value i, s)
arrayLoopWithExit len ptr0 start loopBody = do
   (remaining, end) <-
      mapWhileState_
         (\(i,ptr) (_i,s) -> do
            (goOn, s1) <- loopBody ptr s
            -- The counter delivers len, len-1, ..., 1.
            -- Count the current element as done only if we go on,
            -- so that a complete run ends with remaining == 0 (pos == len),
            -- whereas an exit at counter value i leaves remaining == i
            -- (pos == index of the element that requested the exit).
            remaining1 <- C.ifThen goOn i (A.dec i)
            return (goOn, (remaining1, s1)))
         (liftA2 (,) (countDown len) (arrayPtrs ptr0))
         (len,start)
   pos <- A.sub len remaining
   return (pos, end)

{- |
Like 'arrayLoop', but walks two arrays in lockstep
and passes one pointer into each to @loopBody@.
-}
arrayLoop2 ::
   (Tuple.Phi s, IsType a, IsType b, Num i, IsConst i, IsInteger i, CmpRet i, IsPrimitive i) =>
   Value i -> Value (LLVM.Ptr a) -> Value (LLVM.Ptr b) -> s ->
   (Value (LLVM.Ptr a) -> Value (LLVM.Ptr b) -> s -> CodeGenFunction r s) ->
   CodeGenFunction r s
arrayLoop2 len ptrA ptrB start loopBody =
   mapState_ (\ ~(pa, pb) -> loopBody pa pb) pointerPairs start
   where
      pointerPairs = take len (liftA2 (,) (arrayPtrs ptrA) (arrayPtrs ptrB))