{-# LANGUAGE PatternGuards, TemplateHaskell, TypeOperators, FlexibleInstances, FlexibleContexts, UndecidableInstances, TypeFamilies, MultiParamTypeClasses #-}

module Data.TrieMap.MultiRec.ProdMap () where

import Data.TrieMap.MultiRec.Class
import Data.TrieMap.MultiRec.Eq
-- import Data.TrieMap.MultiRec.Ord
import Data.TrieMap.MultiRec.Sized
-- import Data.TrieMap.MultiRec.TH
-- import Data.TrieMap.Regular.Eq
-- import Data.TrieMap.Regular.Ord
-- import Data.TrieMap.Regular.Base (O(..))
import Data.TrieMap.Applicative
import Data.TrieMap.TrieKey
-- import Data.TrieMap.Rep
-- import Data.TrieMap.Rep.TH

import Control.Applicative
import Control.Arrow

import Data.Maybe
import Data.Monoid
import Data.Foldable
import Data.Sequence ((|>))
import qualified Data.Sequence as Seq

import Generics.MultiRec

-- | Trie for product keys @f :*: g@: an outer trie keyed on the @f@
-- component whose values are inner tries keyed on the @g@ component
-- (a curried representation of the product key).
newtype ProdMap (phi :: * -> *) f g (r :: * -> *) ix a = PMap (HTrieMapT phi f r ix (HTrieMapT phi g r ix a))
-- The trie for a product functor @f :*: g@ is a 'ProdMap' over the component tries.
type instance HTrieMapT phi (f :*: g) = ProdMap phi f g--(HTrieMapT phi f) (HTrieMapT phi g)
-- type instance HTrieMap phi ((f :*: g) r) = HTrieMapT phi (f :*: g) r

-- type instance RepH (ProdMap phi f g r ix) = RepH (HTrieMapT phi f r ix) `O` RepH (HTrieMapT phi g r ix)
-- type instance Rep (ProdMap phi f g r ix a) = RepH (ProdMap phi f g r ix) (Rep a)

-- -- $(genRepH [d|
-- 	instance (ReprH (HTrieMapT phi f r ix), ReprH (HTrieMapT phi g r ix)) =>
-- 			ReprH (ProdMap phi f g r ix) where
-- 		toRepH (PMap m) = O (fmap toRepH (toRepH m))
-- 		fromRepH (O m) = PMap (fromRepH (fmap fromRepH m)) |] )

-- | Index, within the trie @m@, at which its greatest element starts: the
-- total size of @m@ minus the size of the element that 'aboutH' selects in
-- the 'Last' context. This is the same offset formula that the local
-- @getLast@ helper in 'lookupIxH' uses.
--
-- The previous version selected the element through 'getFirst' (the
-- least element) while applying the last-element offset formula. As a
-- result it returned @size - size(first)@, which is only correct when all
-- element sizes coincide.
--
-- For an empty trie (no element selected) it falls back to @sizeH pf s m@.
maxIx :: (HTrieKeyT phi f (HTrieMapT phi f), HTrieKey phi r (HTrieMap phi r)) => phi ix -> HSized phi a -> 
		HTrieMapT phi f r ix a -> Int
maxIx pf s m = fromMaybe total (getLast (aboutH pf (\ _ a -> return (total - s a)) m))
	where
		-- Total size of the trie, shared by the result and the fallback.
		total = sizeH pf s m

-- | Tries on product keys @a :*: b@. The outer trie is keyed on @a@ and
-- sized by the total size of each inner trie (@sizeH pf s@). Each outer
-- value is an inner trie keyed on @b@.
--
-- Mutating operations pass inner results through 'guardNullH'. This is
-- presumably so that empty inner tries are never stored, which 'nullH'
-- relies on. NOTE(review): confirm against guardNullH's definition.
instance (HTrieKeyT phi f (HTrieMapT phi f), HTrieKeyT phi g (HTrieMapT phi g)) => 
	HTrieKeyT phi (f :*: g) (ProdMap phi f g) where
	-- An empty product map wraps an empty outer trie.
	emptyH = PMap . emptyH
	-- Null exactly when the outer trie is null (see the invariant note above).
	nullH pf (PMap m) = nullH pf m
	-- Total size: the outer trie measured with inner-trie sizes.
	sizeH pf s (PMap m) = sizeH pf (sizeH pf s) m
	-- Look up @a@ in the outer trie, then @b@ in the inner trie found there.
	lookupH pf (a :*: b) (PMap m) = lookupH pf a m >>= lookupH pf b
	-- Indexed lookup returning a (predecessor, match, successor) triple.
	-- The outer lookup of @a@ yields such a triple over inner tries. When a
	-- match exists, lookupX looks up @b@ in its inner trie and adjusts the
	-- inner triple with f. Here onIndexA (i +) presumably offsets indices by
	-- the outer index, and onKeyA (a' :*:) prefixes keys with the outer key.
	-- Predecessor candidates, combined with <|>:
	--   * the element getLast picks from the outer predecessor's inner trie;
	--   * the inner predecessor from lookupX.
	-- Successor candidates, combined with <|>:
	--   * the inner successor from lookupX;
	--   * the element getFirst picks from the outer successor's inner trie.
	-- NOTE(review): which candidate wins depends on the First/Last
	-- Alternative instances. Presumably the inner neighbour wins; confirm.
	-- The local getFirst tags the element aboutH selects with index 0.
	-- The local getLast tags it with (trie size - element size).
	lookupIxH pf s (a :*: b) (PMap m) = case lookupIxH pf (sizeH pf s) a m of
		(lb, x, rb) -> let lookupX = do	Asc i a' m' <- x
						let (lb', x', rb') = lookupIxH pf s b m'
						let f = onIndexA (i +) . onKeyA (a' :*:)
						return (f <$> lb', f <$> x', f <$> rb')
		   in 	((do	Asc iA aL mL <- lb
				fmap (onIndexA (iA +) . onKeyA (aL :*:)) (getLast pf s mL)) <|>
			 (do	(lb', _, _) <- Last lookupX
				lb'),
			 (do	(_, x', _) <- lookupX
				x'),
			 (do	(_, _, rb') <- First lookupX
				rb') <|>
			 (do	Asc iA aR mR <- rb
			  	fmap (onIndexA (iA +) . onKeyA (aR :*:)) (getFirst pf s mR)))
		where 	getLast pf s m = aboutH pf (\ k a -> return (Asc (sizeH pf s m - s a) k a)) m
			getFirst pf s m = aboutH pf (\ k a -> return (Asc 0 k a)) m
	-- Positional access, with the same structure as lookupIxH. The outer
	-- lookup is by index i. Inside the matched inner trie, the lookup is by
	-- the remaining offset (i - i'), and results are shifted back by i'.
	assocAtH pf s i (PMap m) = case assocAtH pf (sizeH pf s) i m of
		(lb, x, rb) -> let lookupX = do	Asc i' a' m' <- x
						let (lb', x', rb') = assocAtH pf s (i - i') m'
						let f = onIndexA (i' +) . onKeyA (a' :*:)
						return (f <$> lb', f <$> x', f <$> rb')
			in ((do	Asc iA aL mL <- lb
				fmap (onIndexA (iA +) . onKeyA (aL :*:)) (getLast pf s mL)) <|>
			    (do	(lb', _, _) <- Last lookupX
			    	lb'),
			    (do	(_, x', _) <- lookupX
			    	x'),
			    (do	(_, _, rb') <- First lookupX
			    	rb') <|>
			    (do	Asc iA aR mR <- rb
			    	fmap (onIndexA (iA +) . onKeyA (aR :*:)) (getFirst pf s mR)))
		where 	getLast pf s m = aboutH pf (\ k a -> return (Asc (sizeH pf s m - s a) k a)) m
			getFirst pf s m = aboutH pf (\ k a -> return (Asc 0 k a)) m
	-- NOTE(review): disabled updateAtH implementation kept for reference.
	-- It is the only user of maxIx in this module.
-- 	updateAtH pf s r f i (PMap m) = PMap (updateAtH pf (sizeH pf s) r g i m) where
-- 		g iA a m 
-- 			| i >= iA && i <= iA + maxIx pf s m
-- 					= (guardNullH pf . updateAtH pf s r (\ iB b -> f (iA + iB) (a :*: b)) (i - iA)) m
-- 				| i < iA
-- 					= guardNullH pf $
-- 						alterMaxH pf s (\ b v -> f (iA + sizeH pf s m - s v) (a :*: b) v) m
-- 				| otherwise
-- 					= guardNullH pf $ alterMinH pf s (f iA . (a :*:)) m
	-- Alter @b@ in the inner trie stored under @a@ (an empty trie if @a@ is
	-- absent). The new inner trie is passed through guardNullH before it
	-- goes back into the outer trie.
	alterH pf s f (a :*: b) (PMap m) = PMap (alterH pf (sizeH pf s) (guardNullH pf . g) a m) where
		g = alterH pf s f b . fromMaybe (emptyH pf)
	-- Like alterH, but the inner alterLookupH's functor result is threaded
	-- outward, with guardNullH and PMap mapped over it.
	alterLookupH pf s f (a :*: b) (PMap m) = PMap <$> alterLookupH pf (sizeH pf s) g a m where
		g = fmap (guardNullH pf) . alterLookupH pf s f b . fromMaybe (emptyH pf)
	-- Nested traversal. The user function sees the recombined key a :*: b.
	traverseWithKeyH pf s f (PMap m) = 
		PMap <$> traverseWithKeyH pf (sizeH pf s) (\ a -> traverseWithKeyH pf s (\ b -> f (a :*: b))) m
	-- Nested right-style fold over outer then inner keys.
	foldWithKeyH pf f (PMap m) =
		foldWithKeyH pf (\ a -> foldWithKeyH pf (\ b -> f (a :*: b))) m
	-- Nested left-style fold. flip adapts the inner fold to the outer
	-- accumulator-first argument order.
	foldlWithKeyH pf f (PMap m) =
		foldlWithKeyH pf (\ a -> foldlWithKeyH pf (\ b -> f (a :*: b))) m
	-- Split every inner trie in two. Both halves go through guardNullH.
	mapEitherH pf s1 s2 f (PMap m) = (PMap *** PMap) (mapEitherH pf (sizeH pf s1) (sizeH pf s2) g m) where
		g a = (guardNullH pf *** guardNullH pf) . mapEitherH pf s1 s2 (\ b -> f (a :*: b))
	-- Split the outer trie at @a@ and the inner trie at @b@. `sides`
	-- presumably applies its function to the left and right pieces only.
	splitLookupH pf s f (a :*: b) (PMap m) = PMap `sides` splitLookupH pf (sizeH pf s) g a m where
		g = sides (guardNullH pf) . splitLookupH pf s f b
	-- unionH, isectH and diffH: when both maps have outer key @a@, their inner
	-- tries are combined with the same operation. Results go through
	-- guardNullH.
	unionH pf s f (PMap m1) (PMap m2) = PMap (unionH pf (sizeH pf s) g m1 m2) where
		g a = guardNullH pf .: unionH pf s (\ b -> f (a :*: b))
	isectH pf s f (PMap m1) (PMap m2) = PMap (isectH pf (sizeH pf s) g m1 m2) where
		g a = guardNullH pf .: isectH pf s (\ b -> f (a :*: b))
	diffH pf s f (PMap m1) (PMap m2) = PMap (diffH pf (sizeH pf s) g m1 m2) where
		g a = guardNullH pf .: diffH pf s (\ b -> f (a :*: b))
	-- Nested extraction. The remaining inner trie goes through guardNullH,
	-- and the remaining outer trie is rewrapped in PMap.
	extractH pf s f (PMap m) = fmap PMap <$> extractH pf (sizeH pf s) g m where
		g a = fmap (guardNullH pf) <.> extractH pf s (\ b -> f (a :*: b))
	-- NOTE(review): disabled min/max extraction and alteration methods kept for reference.
-- 	extractMinH pf s f (PMap m) = second PMap <$> extractMinH pf (sizeH pf s) g m where 
-- 			g a m1 = fromJust $ getFirst $ second (guardNullH pf) <$> extractMinH pf s (f . (a :*:)) m1
-- 		extractMaxH pf s f (PMap m) = second PMap <$> extractMaxH pf (sizeH pf s) g m where 
-- 			g a m1 = fromJust $ getLast $ second (guardNullH pf) <$> extractMaxH pf s (f . (a :*:)) m1
-- 		alterMinH pf s f (PMap m) = PMap (alterMinH pf (sizeH pf s) g m) where
-- 			g a = guardNullH pf . alterMinH pf s (\ b -> f (a :*: b))
-- 		alterMaxH pf s f (PMap m) = PMap (alterMaxH pf (sizeH pf s) g m) where
-- 			g a = guardNullH pf . alterMaxH pf s (\ b -> f (a :*: b))
	-- Outer submap check, where inner tries are compared with the inner
	-- submap check.
	isSubmapH pf (<=) (PMap m1) (PMap m2) = isSubmapH pf (isSubmapH pf (<=)) m1 m2
	-- Build from an arbitrary list in three steps:
	--   1. breakFst groups adjacent entries that share a first component.
	--   2. The groups go into an intermediate outer trie of lists (each
	--      sized 1). Groups with equal keys are merged by list concatenation.
	--   3. Each list is turned into an inner trie with fromListH, and the
	--      outer trie is re-sized by inner size.
	fromListH pf s f xs = PMap (mapWithKeyH pf (sizeH pf s) (\ a -> fromListH pf s (\ b -> f (a :*: b)))
				(fromListH pf (const 1) (\ _ (xs) (ys) -> (xs ++ ys))
					[(a, ts) | (a, ts) <- breakFst pf xs]))
	-- Ascending input keeps equal first components adjacent, so breakFst
	-- yields distinct outer keys in order. The outer trie can therefore be
	-- built with fromDistAscListH, and each inner trie with fromAscListH.
	fromAscListH pf s f xs = PMap (fromDistAscListH pf (sizeH pf s)
		[(a, fromAscListH pf s (\ b -> f (a :*: b)) ts) | (a, ts) <- breakFst pf xs])
	-- Distinct ascending input: both levels are built with fromDistAscListH.
	fromDistAscListH pf s xs = PMap (fromDistAscListH pf (sizeH pf s)
		[(a, fromDistAscListH pf s ts) | (a, ts) <- breakFst pf xs])

-- | Group a list of product-keyed entries into runs of /adjacent/ entries
-- whose first components are equal according to 'heqT'. Each run is paired
-- with its first component and the list of (second component, value) pairs,
-- in input order.
--
-- Every later entry in a run is compared against the run's first key
-- (@heqT pf runKey otherKey@). Equal keys separated by a different key start
-- a new run.
breakFst :: (HEq phi f, HEq0 phi r) => phi ix -> [((f :*: g) r ix, a)] -> [(f r ix, [(g r ix, a)])]
breakFst _ [] = []
breakFst pf ((key :*: rest, val) : more) =
	(key, (rest, val) : map dropKey run) : breakFst pf others
	where
		-- Split off the maximal prefix sharing this run's key.
		(run, others) = span sameKey more
		sameKey (key' :*: _, _) = heqT pf key key'
		-- Keep only the second component and the value.
		dropKey (_ :*: rest', val') = (rest', val')