module LLVM.Extra.Multi.Iterator (
   takeWhile,
   countDown,
   take,
   ) where

import qualified LLVM.Extra.Multi.Value as MultiValue
import qualified LLVM.Extra.Iterator as Iter

import qualified LLVM.Core as LLVM
import LLVM.Core (CodeGenFunction)

import Control.Applicative (liftA2)

import Prelude hiding (take, takeWhile)



-- | Variant of 'Iter.takeWhile' for predicates that yield a
-- @'MultiValue.T' Bool@.
--
-- The predicate result is converted with 'unpackBool' and the adapted
-- predicate is handed to 'Iter.takeWhile' together with the iterator.
takeWhile ::
   (a -> CodeGenFunction r (MultiValue.T Bool)) ->
   Iter.T r a -> Iter.T r a
takeWhile p =
   Iter.takeWhile (\x -> fmap unpackBool (p x))

-- | Extract the @'LLVM.Value' Bool@ held by the 'MultiValue.Cons'
-- constructor of a @'MultiValue.T' Bool@.
unpackBool :: MultiValue.T Bool -> LLVM.Value Bool
unpackBool wrapped =
   case wrapped of
      MultiValue.Cons flag -> flag

-- | Iterator built by starting at @len@ and repeatedly applying
-- 'MultiValue.dec', cut off by 'takeWhile' as soon as
-- @'MultiValue.zero' < value@ (via 'LLVM.CmpLT') no longer holds.
--
-- NOTE(review): presumably this yields @len, len-1, ..., 1@ -
-- confirm against the semantics of 'Iter.iterate' and 'Iter.takeWhile'.
countDown ::
   (MultiValue.Additive i, MultiValue.Comparison i,
    MultiValue.IntegerConstant i) =>
   MultiValue.T i -> Iter.T r (MultiValue.T i)
countDown len =
   takeWhile aboveZero (Iter.iterate MultiValue.dec len)
   where
      -- Zero is the left operand, i.e. this tests @0 < counter@.
      aboveZero counter =
         MultiValue.cmp LLVM.CmpLT MultiValue.zero counter

-- | Combine @xs@ with @'countDown' len@ through 'liftA2', keeping only
-- the elements of @xs@ and discarding the counter values.
--
-- NOTE(review): presumably the 'Applicative' instance of 'Iter.T' is
-- zip-like, so that the result stops once the counter is exhausted,
-- limiting the output to at most @len@ elements - confirm against
-- "LLVM.Extra.Iterator".
take ::
   (MultiValue.Additive i, MultiValue.Comparison i,
    MultiValue.IntegerConstant i) =>
   MultiValue.T i -> Iter.T r a -> Iter.T r a
take len xs =
   liftA2 keepElement xs (countDown len)
   where
      keepElement x _counter = x