{-# LANGUAGE TemplateHaskell, PatternGuards, TypeFamilies, MultiParamTypeClasses, FlexibleContexts,  TypeOperators, UndecidableInstances #-}

module Data.TrieMap.Regular.ProdMap() where

import Data.TrieMap.Regular.Class
import Data.TrieMap.Regular.Base
import Data.TrieMap.Regular.Eq
import Data.TrieMap.TrieKey
import Data.TrieMap.Applicative
import Data.TrieMap.Sized
-- import Data.TrieMap.Regular.TH

import Control.Applicative
import Control.Arrow

import Data.Maybe
import Data.Monoid
import Data.Sequence (Seq, (|>))
import qualified Data.Sequence as Seq
import Data.Foldable

-- | Trie for product keys @(f :*: g) k@: an outer trie @m1@ keyed on the
-- @f k@ component whose values are inner tries @m2@ keyed on the @g k@
-- component.
newtype PMap m1 (m2 :: * -> * -> *) k a = PMap (m1 k (m2 k a))
-- The trie for a product functor nests the component functors' tries.
type instance TrieMapT (f :*: g) = PMap (TrieMapT f) (TrieMapT g)
type instance TrieMap ((f :*: g) r) = TrieMapT (f :*: g) r

-- | Starting index of the entry that 'aboutT' reports in the 'Last' monad
-- (presumably the maximal entry), computed as the map's total size under
-- the sizing function @s@ minus that entry's own size.  When 'aboutT'
-- yields no entry, the total size itself is returned.
lastIx :: (TrieKeyT f (TrieMapT f), TrieKey k (TrieMap k)) => Sized a -> TrieMapT f k a -> Int
lastIx s m = fromMaybe total $ getLast $ aboutT (\ _ a -> return (total - s a)) m
  where
    total = sizeT s m

--maybe (sizeT s m) fst (getLast (extractMaxT s (\ _ a -> (sizeT s m - s a, Just a)) m))

-- Product keys @(f :*: g) k@ satisfy 'TrieKey' by forwarding every method
-- to the matching 'TrieKeyT' method at the functor level.
instance (TrieKeyT f m1, TrieKeyT g m2, TrieKey k (TrieMap k)) =>
    TrieKey ((f :*: g) k) (PMap m1 m2 k) where
  alterLookupM     = alterLookupT
  alterM           = alterT
  assocAtM         = assocAtT
  diffM            = diffT
  emptyM           = emptyT
  extractM         = extractT
  foldWithKeyM     = foldWithKeyT
  foldlWithKeyM    = foldlWithKeyT
  fromAscListM     = fromAscListT
  fromDistAscListM = fromDistAscListT
  fromListM        = fromListT
  isSubmapM        = isSubmapT
  isectM           = isectT
  lookupIxM        = lookupIxT
  lookupM          = lookupT
  mapEitherM       = mapEitherT
  nullM            = nullT
  splitLookupM     = splitLookupT
  traverseWithKeyM = traverseWithKeyT
  unionM           = unionT

-- Product keys @a :*: b@ are stored as a two-level trie: the outer trie (m1)
-- is keyed on the @a@ component and each outer value is an inner trie (m2)
-- keyed on the @b@ component.  Outer values are weighed with @sizeT s@, i.e.
-- by the total size of the inner trie they hold, so an outer entry's index
-- is the global index of the first entry of its inner trie.
instance (TrieKeyT f m1, TrieKeyT g m2) => TrieKeyT (f :*: g) (PMap m1 m2) where
	emptyT = PMap emptyT
	nullT (PMap m) = nullT m
	-- Outer size where each inner trie counts as its own total size.
	sizeT s (PMap m) = sizeT (sizeT s) m
	-- Find @a@ in the outer trie, then @b@ in the inner trie stored there.
	lookupT (a :*: b) (PMap m) = lookupT a m >>= lookupT b
	-- Indexed lookup of @a :*: b@, returning (predecessor, match, successor).
	-- The outer lookup on @a@ yields outer neighbours (lb, ub) and, if @a@ is
	-- present, its entry @Asc i' a' m'@; the inner lookup of @b@ in @m'@ is
	-- re-keyed with @a' :*:@ and shifted by the entry's start index @i'@.
	-- The predecessor combines (max of lb's inner trie) <|> (inner predecessor);
	-- the successor combines (inner successor) <|> (min of ub's inner trie).
	-- getMin tags an entry with index 0; getMax tags it with the inner trie's
	-- size minus the entry's size.
	-- NOTE(review): for the nearest neighbour to win, <|> on 'Last' must prefer
	-- its right operand and <|> on 'First' its left — confirm in
	-- Data.TrieMap.Applicative.
	lookupIxT s (a :*: b) (PMap m) = case lookupIxT (sizeT s) a m of
		(lb, x, ub) -> let lookupX = do	Asc i' a' m' <- x
						let (lb', x', ub') = lookupIxT s b m'
						let f = onKeyA (a' :*:) . onIndexA (i' +)
						return (f <$> lb', f <$> x', f <$> ub')
			in ((do	Asc iL aL mL <- lb
				fmap (onKeyA (aL :*:) . onIndexA (iL +)) (getMax s mL)) <|>
			    (do	(lb', _, _) <- Last lookupX
			    	lb'),
			    (do	(_, x', _) <- lookupX
			    	x'),
			    (do	(_, _, ub') <- First lookupX
			    	ub') <|>
			    (do Asc iR aR mR <- ub
			    	fmap (onKeyA (aR :*:) . onIndexA (iR +)) (getMin s mR))) 
		where	getMin s m = aboutT (\ k a -> return (Asc 0 k a)) m
			getMax s m = aboutT (\ k a -> return (Asc (sizeT s m - s a) k a)) m
	-- Same shape as lookupIxT, but locating by global index @i@: the outer
	-- entry containing @i@ starts at @i'@, so the inner search uses the
	-- relative index @i - i'@.
	assocAtT s i (PMap m) = case assocAtT (sizeT s) i m of
		(lb, x, ub) -> let lookupX = do	Asc i' a' m' <- x
						let (lb', x', ub') = assocAtT s (i - i') m'
						let f = onKeyA (a' :*:) . onIndexA (i' +)
						return (f <$> lb', f <$> x', f <$> ub')
			in ((do	Asc iL aL mL <- lb
				fmap (onKeyA (aL :*:) . onIndexA (iL +)) (getMax mL)) <|>
			    (do	(lb', _, _) <- Last lookupX
			    	lb'),
			    (do	(_, x', _) <- lookupX
			    	x'),
			    (do	(_, _, ub') <- First lookupX
			    	ub') <|>
			    (do	Asc iR aR mR <- ub
			    	fmap (onKeyA (aR :*:) . onIndexA (iR +)) (getMin mR)))
		where	getMin m = aboutT (\ k a -> return (Asc 0 k a)) m
			getMax m = aboutT (\ k a -> return (Asc (sizeT s m - s a) k a)) m
-- Disabled updateAtT implementation, kept for reference:
-- 	updateAtT s r f i (PMap m) = PMap (updateAtT (sizeT s) r g i m) where
-- 		g iA a m'
-- 			| not r && i < iA
-- 				= guardNullT (alterMinT s (f iA . (a :*:)) m')
-- 			| r && i >= iA + lastIx s m'
-- 				= guardNullT (alterMaxT s (f (lastIx s m') . (a :*:)) m')
-- 			| otherwise
-- 				= guardNullT (updateAtT s r (\ i' -> f (iA + i') . (a :*:)) (i - iA) m')
	-- Alter the inner trie for @a@ (starting from an empty one if absent);
	-- guardNullT is applied to the result, presumably so an emptied inner
	-- trie removes the outer entry.
	alterT s f (a :*: b) (PMap m) = PMap (alterT (sizeT s) g a m) where
		g = guardNullT . alterT s f b . fromMaybe emptyT
	alterLookupT s f (a :*: b) (PMap m) = PMap <$> alterLookupT (sizeT s) g a m where
		g = fmap guardNullT . alterLookupT s f b . fromMaybe emptyT
	-- Traversals and folds nest the inner operation inside the outer one,
	-- rebuilding full keys as @a :*: b@.
	traverseWithKeyT s f (PMap m) = PMap <$> traverseWithKeyT (sizeT s) g m where
		g a = traverseWithKeyT s (\ b -> f (a :*: b))
	foldWithKeyT f (PMap m) = foldWithKeyT g m where
		g a = foldWithKeyT (\ b -> f (a :*: b))
	foldlWithKeyT f (PMap m) = foldlWithKeyT g m where
		g a z m = foldlWithKeyT (\ b -> f (a :*: b)) m z
	mapEitherT s1 s2 f (PMap m) = (PMap *** PMap) (mapEitherT (sizeT s1) (sizeT s2) g m) where
		g a = (guardNullT *** guardNullT) . mapEitherT s1 s2 (\ b -> f (a :*: b))
	splitLookupT s f (a :*: b) (PMap m) = PMap `sides` splitLookupT (sizeT s) g a m where
		g = sides guardNullT . splitLookupT s f b
	-- Binary combinators work outer-key-wise, combining the inner tries of
	-- matching outer keys and passing results through guardNullT.
	unionT s f (PMap m1) (PMap m2) = PMap (unionT (sizeT s) (\ a -> guardNullT .: unionT s (\ b -> f (a :*: b))) m1 m2)
	isectT s f (PMap m1) (PMap m2) = PMap (isectT (sizeT s) (\ a -> guardNullT .: isectT s (\ b -> f (a :*: b))) m1 m2)
	diffT s f (PMap m1) (PMap m2) = PMap (diffT (sizeT s) (\ a -> guardNullT .: diffT s (\ b -> f (a :*: b))) m1 m2)
	extractT s f (PMap m) = fmap PMap <$> extractT (sizeT s) g m where
		g a = fmap guardNullT <.> extractT s (\ b -> f (a :*: b))
-- Disabled extractMinT/extractMaxT/alterMinT/alterMaxT implementations:
-- 	extractMinT s f (PMap m) = second PMap <$> extractMinT (sizeT s) g m where
-- 		g a = second guardNullT . fromJust . getFirst . extractMinT s (f . (a :*:))
-- 	extractMaxT s f (PMap m) = second PMap <$> extractMaxT (sizeT s) g m where
-- 		g a = second guardNullT . fromJust . getLast . extractMaxT s (f . (a :*:))
-- 	alterMinT s f (PMap m) = PMap (alterMinT (sizeT s) (\ a -> guardNullT . alterMinT s (\ b -> f (a :*: b))) m)
-- 	alterMaxT s f (PMap m) = PMap (alterMaxT (sizeT s) (\ a -> guardNullT . alterMaxT s (\ b -> f (a :*: b))) m)
	isSubmapT (<=) (PMap m1) (PMap m2) = isSubmapT (isSubmapT (<=)) m1 m2 
	-- Group adjacent entries by first component (breakFst), merge groups that
	-- share an outer key with (++) in an intermediate trie whose values are
	-- sized as 1, then build each inner trie from its collected list.
	-- NOTE(review): the order in which fromListT hands duplicates to
	-- @const (++)@ decides the order duplicate inner keys reach @f@ — confirm
	-- it matches a flat fromListT.
	fromListT s f xs = PMap (mapWithKeyT (sizeT s) (\ a -> fromListT s (\ b -> f (a :*: b))) 
		(fromListT (const 1) (const (++)) (breakFst xs)))
	-- Assuming ascending input, breakFst's runs carry distinct, ascending outer
	-- keys, so the outer trie is built with fromDistAscListT.
	fromAscListT s f xs = PMap (fromDistAscListT (sizeT s)
		[(a, fromAscListT s (\ b -> f (a :*: b)) ys) | (a, ys) <- breakFst xs])
	
-- | Split an association list of product keys into runs sharing a first
-- component.
--
-- Each run starts at an entry @(a :*: b, v)@ and extends over the following
-- entries whose first component is 'eqT' to that @a@.  The run becomes
-- @(a, [(b, v), ...])@ with the entries kept in their original order.  Only
-- adjacent entries are grouped, so a first component that reappears later
-- starts a new run.
breakFst :: (EqT f, Eq k) => [((f :*: g) k, a)] -> [(f k, [(g k, a)])]
breakFst [] = []
breakFst ((a :*: b, v) : rest) = (a, (b, v) : map dropFst run) : breakFst others
  where
    (run, others) = span (\ (a' :*: _, _) -> a `eqT` a') rest
    dropFst (_ :*: b', v') = (b', v')