{-# LANGUAGE FlexibleContexts, UndecidableInstances, MultiParamTypeClasses, TypeFamilies #-}

module Data.TrieMap.ProdMap () where

import Data.TrieMap.TrieKey
-- import Data.TrieMap.Sized
import Data.TrieMap.Applicative
import Data.TrieMap.Regular.Class
-- import Data.TrieMap.Regular.TH

import Control.Applicative
import Control.Arrow

import Data.Maybe
import Data.Monoid
import Data.Foldable

import Data.Sequence (Seq, (|>))
import qualified Data.Sequence as Seq

-- | Map representation for pair keys: an outer container @m1@ whose values
-- are themselves @TrieMap k2 a@ maps, i.e. one inner map per first key component.
newtype PMap m1 k2 a = PMap (m1 (TrieMap k2 a))
-- With the first component fixed to @a@, the map transformer for @((,) a)@ is
-- 'PMap' over the trie map of @a@.
type instance TrieMapT ((,) a) = PMap (TrieMap a)
-- The trie map for a pair @(a, b)@ is a 'PMap' whose outer layer is @TrieMap a@
-- and whose inner maps are keyed by @b@.
type instance TrieMap (a, b) = PMap (TrieMap a) b
-- type instance TrieMap (a, b) = PMap (TrieMap a) (TrieMap b)

-- | 'TrieKey' instance for pair keys backed by 'PMap'.
--
-- Every method bound here is a direct alias of the same-named 'TrieKeyT' method
-- (@fooM = fooT@); no additional logic lives in this instance.
--
-- NOTE(review): @fromDistAscListT@ is referenced below but is not bound in the
-- @TrieKeyT ((,) k1) (PMap m1)@ instance visible in this module -- presumably it
-- comes from a class default; confirm. Likewise @sizeM@ is not bound here even
-- though a @sizeT@ exists; presumably it is not a 'TrieKey' method or has a
-- default -- confirm against the class definition.
instance (TrieKey a m, TrieKey b (TrieMap b)) => TrieKey (a, b) (PMap m b) where
	-- Construction and basic queries.
	emptyM = emptyT
	nullM = nullT
	lookupM = lookupT
	-- Index-based queries.
	lookupIxM = lookupIxT
	assocAtM = assocAtT
	-- Point updates.
	alterM = alterT
	alterLookupM = alterLookupT
	-- Traversals and folds.
	traverseWithKeyM = traverseWithKeyT
	foldWithKeyM = foldWithKeyT
	foldlWithKeyM = foldlWithKeyT
	-- Partitioning and set-like combinations.
	mapEitherM = mapEitherT
	splitLookupM = splitLookupT
	unionM = unionT
	isectM = isectT
	diffM = diffT
	extractM = extractT
	isSubmapM = isSubmapT
	-- Bulk construction from lists.
	fromListM = fromListT
	fromAscListM = fromAscListT
	fromDistAscListM = fromDistAscListT

-- | 'TrieKeyT' instance for the pair constructor applied to its first component.
--
-- A 'PMap' is an outer map (keyed by the first component) of inner maps (keyed by
-- the second component). Each method below runs the corresponding @...M@ operation
-- on the outer map and, per outer key, the same operation on the inner map.
-- Wherever a size function @s@ for values is supplied, the outer operation is given
-- @sizeM s@, i.e. an inner map is measured by its own size under @s@.
--
-- Results of inner operations that may leave an inner map empty are passed through
-- @guardNullM@ (defined outside this block; the commented-out copy at the end of
-- this file returns @Nothing@ for a null map and @Just@ otherwise).
instance TrieKey k1 m1 => TrieKeyT ((,) k1) (PMap m1) where
	-- The empty product map wraps the empty outer map.
	emptyT = PMap emptyM
	-- Nullness is decided on the outer map alone.
	-- NOTE(review): presumably relies on empty inner maps never being stored
	-- (see the guardNullM uses below) -- confirm.
	nullT (PMap m) = nullM m
	-- Size of the outer map, with each inner map contributing its size under s.
	sizeT s (PMap m) = sizeM (sizeM s) m
	-- Look up k1 in the outer map, then k2 in the inner map found there.
	lookupT (k1, k2) (PMap m) = lookupM k1 m >>= lookupM k2
	-- Index-aware lookup of the key (a, b).
	-- The outer map is queried for a and yields a triple bound as (lb, x, ub).
	-- lookupX: when x holds an entry (Asc i1 a' m'), the same query is run for b
	-- inside m'; onIndex adds i1 to the indices and onKey pairs the keys with a'.
	-- The result triple is assembled as follows:
	--   1st: the entry aboutM picks out of lb's inner map mL, given index
	--        iL + sizeM s mL - s v, combined by <|> with lookupX's 1st component
	--        (taken through Last);
	--   2nd: lookupX's 2nd component;
	--   3rd: lookupX's 3rd component (taken through First), combined by <|> with
	--        the entry aboutM picks out of ub's inner map mU, given index iU.
	-- NOTE(review): lb / x / ub presumably denote lower neighbour, exact hit and
	-- upper neighbour, with <|> on Last / First preferring the later / earlier
	-- non-empty operand; aboutM, Asc and those instances live elsewhere -- confirm.
	lookupIxT s (a, b) (PMap m) = case lookupIxM (sizeM s) a m of
		(lb, x, ub) -> let lookupX = do	Asc i1 a' m' <- x
						return (onIndex (i1 +) (onKey ((,) a') (lookupIxM s b m')))
			in ((do	Asc iL aL mL <- lb
				aboutM (\ bL v -> return (Asc (iL + sizeM s mL - s v) (aL, bL) v)) mL) <|>
			    (do	(lb', _, _) <- Last lookupX
			    	lb'),
			    (do	(_, x', _) <- lookupX
			    	x'),
			    (do	(_, _, ub') <- First lookupX
			    	ub') <|>
			    (do	Asc iU aU mU <- ub
			    	aboutM (\ bU -> return . Asc iU (aU, bU)) mU))
	-- Positional lookup at index i. Same assembly as lookupIxT, except that the
	-- outer map is queried with assocAtM at i and the inner map m' is queried at
	-- the offset i - i1, where i1 is the index reported for the outer entry.
	assocAtT s i (PMap m) = case assocAtM (sizeM s) i m of
		(lb, x, ub) -> let lookupX = do	Asc i1 a' m' <- x
						return (onIndex (i1 +) (onKey ((,) a') (assocAtM s (i - i1) m')))
			in ((do	Asc iL aL mL <- lb
				aboutM (\ bL v -> return (Asc (iL + sizeM s mL - s v) (aL, bL) v)) mL) <|>
			    (do	(lb', _, _) <- Last lookupX
			    	lb'),
			    (do	(_, x', _) <- lookupX
			    	x'),
			    (do	(_, _, ub') <- First lookupX
			    	ub') <|>
			    (do	Asc iU aU mU <- ub
			    	aboutM (\ bU -> return . Asc iU (aU, bU)) mU))
-- 	updateAtM
	-- Alter the value at (a, b): the outer entry for a is altered by g, which
	-- treats a missing inner map as emptyM, alters it at b with f, and hands the
	-- result to guardNullM.
	alterT s f (a, b) (PMap m) = PMap (alterM (sizeM s) g a m) where
		g = guardNullM . alterM s f b . fromMaybe emptyM
	-- As alterT, but built on alterLookupM; guardNullM is mapped over the inner
	-- result and PMap over the outer result.
	alterLookupT s f (a, b) (PMap m) = PMap <$> alterLookupM (sizeM s) g a m where
		g = fmap guardNullM . alterLookupM s f b . fromMaybe emptyM
	-- Traverse the outer map; for each outer key a, traverse its inner map,
	-- calling f with the recombined key (a, b).
	traverseWithKeyT s f (PMap m) = PMap <$> traverseWithKeyM (sizeM s) (\ a -> traverseWithKeyM s (\ b -> f (a, b))) m
	-- Nested folds: the outer fold's step for key a is itself a fold over the
	-- inner map, with f receiving the pair (a, b).
	foldWithKeyT f (PMap m) = foldWithKeyM (\ a -> foldWithKeyM (\ b -> f (a, b))) m
	-- Left-fold counterpart; flip adapts the inner fold to the argument order the
	-- outer fold's step function is called with.
	foldlWithKeyT f (PMap m) = foldlWithKeyM (\ a -> flip (foldlWithKeyM (\ b -> f (a, b)))) m
	-- Split into two maps via f: each inner map is split with mapEitherM and both
	-- halves go through guardNullM; both outer halves are rewrapped in PMap.
	mapEitherT s1 s2 f (PMap m) = (PMap *** PMap) (mapEitherM (sizeM s1) (sizeM s2) g m) where
		g a = (guardNullM *** guardNullM) . mapEitherM s1 s2 (\ b -> f (a, b))
	-- Split around (a, b): the outer map is split at a, with the inner map found
	-- there split at b; `sides` applies guardNullM (inner) and PMap (outer) to the
	-- resulting pieces. NOTE(review): `sides` is defined elsewhere -- confirm which
	-- components it maps over.
	splitLookupT s f (a, b) (PMap m) = PMap `sides` splitLookupM (sizeM s) g a m where
		g = sides guardNullM . splitLookupM s f b
	-- Submap test: outer submap check whose value comparison is the inner submap
	-- check under the supplied (<=).
	isSubmapT (<=) (PMap m1) (PMap m2) = isSubmapM (isSubmapM (<=)) m1 m2
	-- Union / intersection / difference: the outer operation combines inner maps
	-- sharing an outer key a with the matching inner operation (f sees (a, b)),
	-- and the combined inner map is passed to guardNullM.
	unionT s f (PMap m1) (PMap m2) = PMap (unionM (sizeM s) (\ a -> guardNullM .: unionM s (\ b -> f (a, b))) m1 m2)
	isectT s f (PMap m1) (PMap m2) = PMap (isectM (sizeM s) (\ a -> guardNullM .: isectM s (\ b -> f (a, b))) m1 m2)
	diffT s f (PMap m1) (PMap m2) = PMap (diffM (sizeM s) (\ a -> guardNullM .: diffM s (\ b -> f (a, b))) m1 m2)
	-- Extraction: extractM on the outer map with g, which runs extractM on the
	-- inner map (f sees (a, b)) and maps guardNullM over the remaining inner map;
	-- the remaining outer map is rewrapped in PMap.
	extractT s f (PMap m) = fmap PMap <$> extractM (sizeM s) g m where
		g a = fmap guardNullM <.> extractM s (\ b -> f (a, b))
-- 	extractMinT s f (PMap m) = second PMap <$> extractMinM (sizeM s) g m where
-- 		g a = second guardNullM . fromJust . getFirst . extractMinM s (\ b -> f (a, b))
-- 	extractMaxT s f (PMap m) = second PMap <$> extractMaxM (sizeM s) g m where
-- 		g a = second guardNullM . fromJust . getLast . extractMaxM s (\ b -> f (a, b))
	-- Build from an arbitrary association list: breakFst groups the input by runs
	-- of equal first component; fromListM collects those groups into an outer map
	-- of lists, passing (++) as the combining function for a repeated outer key;
	-- mapWithKeyM then turns each list into an inner map with fromListM.
	fromListT s f xs = PMap (mapWithKeyM (sizeM s) (\ a -> fromListM s (\ b -> f (a, b)))
		(fromListM (const 1) (const (++)) (breakFst xs)))
	-- Build from an ascending association list: each breakFst group becomes an
	-- inner map via fromAscListM, and the groups are handed to fromDistAscListM.
	-- NOTE(review): presumably sound because an ascending input makes the grouped
	-- first components distinct and ascending -- confirm.
	fromAscListT s f xs = PMap (fromDistAscListM (sizeM s)
		[(a, fromAscListM s (\ b -> f (a, b)) ys) | (a, ys) <- breakFst xs])

--    aboutMin :: TrieKey k (TrieMap k) => Sized a -> (k -> a -> x) -> TrieMap k a -> First x
--    aboutMin s f m = fst <$> extractMinM s (\ k a -> (f k a, Nothing)) m
-- 
--    aboutMax :: TrieKey k (TrieMap k) => Sized a -> (k -> a -> x) -> TrieMap k a -> Last x
--    aboutMax s f m = fst <$> extractMaxM s (\ k a -> (f k a, Nothing)) m

-- | Group an association list keyed by pairs into runs that share the same
-- first key component.
--
-- Only /adjacent/ entries are merged: if the same first component shows up again
-- after a different one, it starts a new group. Within a group the entries keep
-- their input order. The key reported for a group is the last of the equal keys
-- seen in that run.
breakFst :: Eq k1 => [((k1, k2), a)] -> [(k1, [(k2, a)])]
breakFst [] = []
breakFst (((k, sub), val) : rest) = collect k (Seq.singleton (sub, val)) rest where
	-- Close the current run: pair its key with the buffered entries, in order.
	flush cur run = (cur, toList run)
	-- Walk the remaining input, buffering entries of the current run in a Seq.
	collect cur run [] = [flush cur run]
	collect cur run (((k', sub'), val') : more)
		| cur == k'	= collect k' (run |> (sub', val')) more
		| otherwise	= flush cur run : collect k' (Seq.singleton (sub', val')) more
	{-
guardNullM :: TrieKey k (TrieMap k) => TrieMap k a -> Maybe (TrieMap k a)
guardNullM m 
	| nullM m	= Nothing
	| otherwise	= Just m-}