{-# LANGUAGE FlexibleContexts, BangPatterns #-}

-- | This is a specialized fusion framework designed for the particular needs
-- of this library.  In particular, it avoids duplicated indices, and has
-- special support for the 'snoc' operation.
-- 
-- It is also compatible with "Data.Vector.Fusion.Stream", being smart enough
-- to convert between the two styles as necessary.
-- 
-- Most functions in this module that produce a vector return a 'VVector'
-- specifically.  If another vector implementation is required, wrap the result
-- in 'convert', which fuses away the intermediate 'VVector'.
module Data.RangeMin.Fusion where

import Control.Monad.ST
import Control.Monad.Primitive
import Data.RangeMin.Common.Types
import Data.RangeMin.Common.Vector
import Data.RangeMin.Common.Combinators
import qualified Data.Vector.Generic as G

import Data.RangeMin.Fusion.Stream (Stream)
import qualified Data.RangeMin.Fusion.Stream as S
import qualified Data.RangeMin.Fusion.Stream.Monadic as SM

import Prelude hiding (mapM_, map, replicate)

{-# INLINE convert #-}
-- | Copies the elements of one vector type into another by streaming the
-- source and rebuilding the target with 'unstream'.
convert :: (Vector v a, Vector v' a) => v a -> v' a
convert src = unstream (stream src)

{-# NOINLINE unstream #-}
-- | Materialises a 'Stream' into a fresh vector.  A buffer of
-- @'S.length' str@ elements is allocated and then populated with 'fill'.
--
-- Kept @NOINLINE@ so that calls to it stay visible to the @RULES@ at the
-- end of this module.  The phase-0 @"unstream"@ rule later rewrites it into
-- a direct allocate-and-write loop.
unstream :: Vector v a => Stream a -> v a
unstream !str = create (new (S.length str) >>= \ !buf -> fill buf str)

{-# INLINE [1] stream #-}
-- | Presents a vector as a 'Stream' of known length.  Each element is read
-- with an unchecked index.
--
-- Inlining is only enabled from phase 1 onwards, presumably so that the
-- @"unstream/stream"@ and @"stream/unstream"@ rules can match first.
stream :: Vector v a => v a -> Stream a
stream !vec = S.generate len (G.unsafeIndex vec)
	where	len = G.length vec

{-# INLINE unzip #-}
-- | Splits a vector of pairs into two vectors in a single pass.
--
-- Both output buffers are allocated up front with the input's length.  Each
-- pair's components are then written to the same index of their respective
-- buffer, and both buffers are frozen in place.
unzip :: (Vector v (a, b), Vector va a, Vector vb b) =>
	v (a, b) -> (va a, vb b)
unzip xs = runST $ do
	let	src = stream xs
		len = S.length src
	!outA <- new len
	!outB <- new len
	S.imapM_ (\ ix (a, b) -> write outA ix a >> write outB ix b) src
	liftM2 (,) (unsafeFreeze outA) (unsafeFreeze outB)

{-# INLINE unzip3 #-}
-- | Splits a vector of triples into three vectors in a single pass.
--
-- All three output buffers are allocated up front with the input's length.
-- Each component is written to the same index of its own buffer, and the
-- buffers are then frozen in place.
unzip3 :: (Vector v (a, b, c), Vector va a, Vector vb b, Vector vc c) =>
	v (a, b, c) -> (va a, vb b, vc c)
unzip3 xs = runST $ do
	let	src = stream xs
		len = S.length src
	!outA <- new len
	!outB <- new len
	!outC <- new len
	S.imapM_ (\ ix (a, b, c) ->
		write outA ix a >> write outB ix b >> write outC ix c) src
	liftM3 (,,) (unsafeFreeze outA) (unsafeFreeze outB) (unsafeFreeze outC)

{-# INLINE enumN #-}
-- | Materialises @'S.enumN' n@ as a 'VVector'.  Presumably this holds the
-- indices @0 .. n - 1@; confirm against "Data.RangeMin.Fusion.Stream".
enumN :: Length -> VVector Int
enumN len = unstream (S.enumN len)

{-# INLINE generate #-}
-- | Builds a 'VVector' of the given length.  The element at each index is
-- produced by applying the function to that index.
generate :: Length -> (Index -> a) -> VVector a
generate len gen = unstream (S.generate len gen)

{-# INLINE imap #-}
-- | Maps a function that also receives each element's index over a vector,
-- producing a 'VVector'.
imap :: (Vector v a) => (Index -> a -> a') -> v a -> VVector a'
imap g vec = unstream (S.imap g (stream vec))

{-# INLINE map #-}
-- | Maps a function over a vector, producing a 'VVector'.  This is 'imap'
-- with the index ignored.
map :: (Vector v a) => (a -> a') -> v a -> VVector a'
map f = imap (\ _ -> f)

{-# INLINE imapAccumL #-}
-- | Indexed map that threads an accumulator through the elements.
--
-- The step function receives the current accumulator, the index and the
-- element.  It returns the output element paired with the next accumulator.
imapAccumL :: (Vector v a) => (b -> Index -> a -> (c, b)) -> b -> v a -> VVector c
imapAccumL step acc0 vec = unstream (S.imapAccumL step acc0 (stream vec))

{-# INLINE imapM_ #-}
-- | Runs a monadic action on every element together with its index and
-- discards the results.
imapM_ :: (Monad m, Vector v a) => (Index -> a -> m b) -> v a -> m ()
imapM_ act vec = S.imapM_ act (stream vec)

{-# INLINE ipostscanl #-}
-- | Indexed left postscan.  Each output element is the accumulator after
-- combining it with the corresponding input element; the initial seed
-- itself is not emitted.
ipostscanl :: (Vector v a) => (b -> Index -> a -> b) -> b -> v a -> VVector b
ipostscanl f = imapAccumL step
	where	step acc ix a = dup (f acc ix a)
		-- the new accumulator is both the emitted element and the carry
		dup y = (y, y)

{-# INLINE mapM_ #-}
-- | Runs a monadic action on every element and discards the results.  This
-- is 'imapM_' with the index ignored.
mapM_ :: (Monad m, Vector v a) => (a -> m b) -> v a -> m ()
mapM_ f = imapM_ (\ _ -> f)

{-# INLINE postscanl #-}
-- | Left postscan without indices.  See 'ipostscanl'.
postscanl :: (Vector v a) => (b -> a -> b) -> b -> v a -> VVector b
postscanl f = ipostscanl (const . f)

{-# INLINE replicate #-}
-- | @replicate n a@ is a 'VVector' of length @n@ in which every element is
-- @a@.
--
-- The length is typed as 'Length', the same type 'generate' takes, rather
-- than a bare 'Int'.
replicate :: Length -> a -> VVector a
replicate n a = generate n (const a)

{-# INLINE snoc #-}
-- | Appends one element to the end of a vector.  It goes through the
-- stream-level 'S.snoc' so that the append can take part in fusion.
snoc :: Vector v a => v a -> a -> VVector a
snoc vec y = unstream (S.snoc (stream vec) y)

{-# INLINE snoc' #-}
-- | Like 'snoc', but the appended element is forced to weak head normal
-- form first.
snoc' :: Vector v a => v a -> a -> VVector a
snoc' vec !y = snoc vec y

{-# INLINE iunfoldN #-}
-- | Indexed unfold into a 'VVector' with a length bound.  The step function
-- sees the current index and seed.  Presumably a 'Nothing' ends the unfold
-- early; see 'S.iunfoldN' for the exact semantics.
iunfoldN :: Length -> (Index -> b -> Maybe (a, b)) -> b -> VVector a
iunfoldN len step seed = unstream (S.iunfoldN len step seed)

{-# INLINE unfoldN #-}
-- | Variant of 'iunfoldN' without indices: the step function sees only the
-- seed.
--
-- The bound is typed as 'Length', matching 'iunfoldN', rather than a bare
-- 'Int'.
unfoldN :: Length -> (b -> Maybe (a, b)) -> b -> VVector a
unfoldN n f = iunfoldN n (const f)

{-# INLINE ifoldl #-}
-- | Indexed left fold over a vector's elements.  The strictness of the
-- accumulator is whatever 'S.ifoldl' provides.
ifoldl :: Vector v a => (b -> Index -> a -> b) -> b -> v a -> b
ifoldl step acc0 vec = S.ifoldl step acc0 (stream vec)

{-# INLINE foldl #-}
-- | Left fold without indices.  See 'ifoldl'.
foldl :: Vector v a => (b -> a -> b) -> b -> v a -> b
foldl f = ifoldl (const . f)

{-# INLINE fromListN #-}
-- | Builds a 'VVector' from a list whose length is supplied up front.  If the
-- given length and the list's actual length disagree, the behaviour is
-- whatever 'S.fromListN' does.
fromListN :: Length -> [a] -> VVector a
fromListN len elems = unstream (S.fromListN len elems)

{-# INLINE [0] munstream #-}
-- | Monadic counterpart of 'unstream'.  It allocates a mutable buffer of the
-- stream's length, fills it with 'fillM', and freezes the result in place
-- without copying.
--
-- 'fillM' already returns the buffer sliced to the stream's length.  That
-- result is frozen directly instead of being discarded and re-sliced.
munstream :: (PrimMonad m, Vector v a) => S.MStream m a -> m (v a)
munstream str = do
	let !n = SM.length str
	!dest <- new n
	fillM dest str >>= unsafeFreeze

{-# INLINE [0] replicateM #-}
-- | Runs the action @n@ times and collects the results into a vector.  The
-- actions presumably run in index order, as driven by 'SM.generateM'.
replicateM :: (PrimMonad m, Vector v a) => Length -> m a -> m (v a)
replicateM len act = munstream (SM.generateM len (\ _ -> act))

{-# INLINE [0] fillM #-}
-- | Writes every element of a monadic stream into the mutable buffer at its
-- stream index.  Returns the prefix of the buffer that covers exactly the
-- stream's length.
--
-- No bounds check is done here.  The buffer is assumed to have room for at
-- least @'SM.length' str@ elements.
fillM :: (PrimMonad m, MVector v a) => v (PrimState m) a -> S.MStream m a -> m (v (PrimState m) a)
fillM !buf !src =
	let !len = SM.length src
	in SM.imapM_ (write buf) src >> return (sliceM 0 len buf)

{-# INLINE [0] fill #-}
-- | Pure-stream version of 'fillM'.  The 'Stream' is lifted into the target
-- monad before the buffer is filled.
fill :: (PrimMonad m, MVector v a) => v (PrimState m) a -> Stream a -> m (v (PrimState m) a)
fill buf src = fillM buf (S.liftStream src)

-- Rewrite rules for this fusion framework:
--
-- * "unstream/stream" cancels a vector -> stream -> vector round trip.
--   NOTE(review): the rule's two sides must have the same type, so this
--   presumably only fires when the source and target vector types coincide.
--   A 'convert' between different types still goes through 'unstream'.
-- * "stream/unstream" cancels building a vector only to stream it again.
-- * "unstream" (phase 0) replaces the NOINLINE 'unstream' with a direct
--   allocate-and-write loop.  It fires only after the rules above have had
--   a chance to match earlier calls.
{-# RULES
	"unstream/stream" forall xs . unstream (stream xs) = xs;
	"stream/unstream" forall str . stream (unstream str) = str;
	"unstream" [0] forall str . unstream str = create $ do
		!dest <- new $! S.length $! str
		S.imapM_ (write dest) (S.liftStream str)
		return dest;
	#-}