{-# LANGUAGE TypeOperators, FlexibleInstances, FlexibleContexts, UndecidableInstances, TypeFamilies, MultiParamTypeClasses #-}

module Data.TrieMap.MultiRec.ProdMap where

import Data.TrieMap.MultiRec.Class
import Data.TrieMap.MultiRec.Eq
import Data.TrieMap.MultiRec.Sized
import Data.TrieMap.Applicative
import Data.TrieMap.TrieKey

import Control.Applicative
import Control.Arrow

import Data.Maybe
import Data.Foldable
import Data.Sequence ((|>))
import qualified Data.Sequence as Seq

import Generics.MultiRec

-- | Trie map for product keys @f :*: g@.  The wrapped value is a map of type
-- @m1@ (over the first component) whose values are maps of type @m2@ (over
-- the second component) holding the actual @a@ values.  @phi@ does not occur
-- in the representation; it is a phantom parameter.
newtype ProdMap (phi :: * -> *) m1 (m2 :: (* -> *) -> (* -> *) -> * -> *) (r :: * -> *) (a :: * -> *) ix = PMap (m1 r (m2 r a) ix)
-- The map type of a product functor is a 'ProdMap' over the map types of its two factors.
type instance HTrieMapT phi (f :*: g) = ProdMap phi (HTrieMapT phi f) (HTrieMapT phi g)
-- The applied form @(f :*: g) r@ reuses the functor-level map type, applied to @r@.
type instance HTrieMap phi ((f :*: g) r) = HTrieMapT phi (f :*: g) r

-- instance (HTrieKey phi (f r), HTrieKey phi (g r)) => HTrieKey phi ((f :*: g) r) where
-- 	emptyH pf ~(a :*: b) = PMap (emptyH pf a)
-- 	nullH pf ~(a :*: b) (PMap m) = nullH pf a m
-- 	lookupH pf (a :*: b) (PMap m) = lookupH pf a m >>= lookupH pf b
-- 	alterH pf f (a :*: b) (PMap m) = PMap (alterH pf (guardNull . g) a m) where
-- 		g = alterH pf f b . fromMaybe (emptyH pf b)
-- 		guardNull m
-- 			| nullH pf b m	= Nothing
-- 			| otherwise	= Just m
-- 	traverseWithKeyH pf f (PMap m) = 
-- 		PMap <$> traverseWithKeyH pf (\ a -> traverseWithKeyH pf (\ b -> f (a :*: b))) m
-- 	foldWithKeyH pf f (PMap m) = 
-- 		foldWithKeyH pf (\ a -> foldWithKeyH pf (\ b -> f (a :*: b))) m

-- | Trie operations for product keys @k1 :*: k2@.  A 'ProdMap' stores, for
-- each first component, an inner map over the second component, so every
-- method runs the outer operation on @k1@ and the same operation on the inner
-- map for @k2@.  Wherever the outer map needs a sizing function it is given
-- @sizeT s@, i.e. the size of the inner maps under the caller's @s@.
--
-- NOTE(review): 'guardNullT', 'sides' and '.:' are defined outside this
-- module.  'guardNullT' is presumed to return 'Nothing' for an empty inner
-- map, so that emptied inner maps are dropped from the outer map -- confirm.
instance (HTrieKeyT phi f m1, m1 ~ HTrieMapT phi f, HTrieKeyT phi g m2, m2 ~ HTrieMapT phi g) =>
		HTrieKeyT phi (f :*: g) (ProdMap phi m1 m2) where
	-- An empty product map is just an empty outer map.
	emptyT pf = PMap (emptyT pf)
	nullT pf (PMap outer) = nullT pf outer
	sizeT s (PMap outer) = sizeT (sizeT s) outer

	-- Find the inner map for the first component, then look inside it.
	lookupT pf (k1 :*: k2) (PMap outer) = do
		inner <- lookupT pf k1 outer
		lookupT pf k2 inner

	-- Same two-step descent as 'lookupT'; the indices reported by the two
	-- levels are added together.
	lookupIxT pf s (k1 :*: k2) (PMap outer) =
		lookupIxT pf (sizeT s) k1 outer >>= \ (off1, inner) ->
		lookupIxT pf s k2 inner >>= \ (off2, val) ->
		return (off1 + off2, val)

	-- The outer level reports an index @off1@ along with the inner map; the
	-- inner level is then queried with the remaining index @i - off1@.
	assocAtT pf s i (PMap outer) =
		case assocAtT pf (sizeT s) i outer of
			(off1, k1, inner) -> case assocAtT pf s (i - off1) inner of
				(off2, k2, val) -> (off1 + off2, k1 :*: k2, val)

	updateAtT pf s f i (PMap outer) = PMap (updateAtT pf (sizeT s) onInner i outer)
		where
			onInner off1 k1 = guardNullT pf . updateAtT pf s (\ off2 k2 -> f (off1 + off2) (k1 :*: k2)) (i - off1)

	alterT pf s f (k1 :*: k2) (PMap outer) = PMap (alterT pf (sizeT s) (guardNullT pf . onInner) k1 outer)
		where
			-- A missing inner map is replaced by an empty one before altering.
			onInner = alterT pf s f k2 . fromMaybe (emptyT pf)

	traverseWithKeyT pf s f (PMap outer) = PMap <$>
		traverseWithKeyT pf (sizeT s) (\ k1 -> traverseWithKeyT pf s (\ k2 -> f (k1 :*: k2))) outer

	foldWithKeyT pf f (PMap outer) =
		foldWithKeyT pf (\ k1 -> foldWithKeyT pf (\ k2 -> f (k1 :*: k2))) outer

	foldlWithKeyT pf f (PMap outer) =
		foldlWithKeyT pf (\ k1 -> flip (foldlWithKeyT pf (\ k2 -> f (k1 :*: k2)))) outer

	-- Both halves of the result are outer maps and are rewrapped in 'PMap'.
	mapEitherT pf s1 s2 f (PMap outer) = (PMap *** PMap) (mapEitherT pf (sizeT s1) (sizeT s2) onInner outer)
		where
			onInner k1 = (guardNullT pf *** guardNullT pf) . mapEitherT pf s1 s2 (\ k2 -> f (k1 :*: k2))

	splitLookupT pf s f (k1 :*: k2) (PMap outer) = sides PMap (splitLookupT pf (sizeT s) onInner k1 outer)
		where
			onInner = sides (guardNullT pf) . splitLookupT pf s f k2

	-- Union, intersection and difference all combine the inner maps that
	-- share a first component, using the same operation one level down.
	unionT pf s f (PMap lhs) (PMap rhs) = PMap (unionT pf (sizeT s) onInner lhs rhs)
		where
			onInner k1 = guardNullT pf .: unionT pf s (\ k2 -> f (k1 :*: k2))

	isectT pf s f (PMap lhs) (PMap rhs) = PMap (isectT pf (sizeT s) onInner lhs rhs)
		where
			onInner k1 = guardNullT pf .: isectT pf s (\ k2 -> f (k1 :*: k2))

	diffT pf s f (PMap lhs) (PMap rhs) = PMap (diffT pf (sizeT s) onInner lhs rhs)
		where
			onInner k1 = guardNullT pf .: diffT pf s (\ k2 -> f (k1 :*: k2))

	-- Extract from the outer map, then from the inner map it yields.  If
	-- 'guardNullT' gives 'Nothing' for what is left of the inner map, the
	-- outer remainder is returned; otherwise the original outer map is kept
	-- with its minimum entry replaced by the reduced inner map.
	extractMinT pf s (PMap outer) = do
		((k1, inner), outerRest) <- extractMinT pf (sizeT s) outer
		((k2, val), innerRest) <- extractMinT pf s inner
		let remaining = case guardNullT pf innerRest of
				Nothing -> outerRest
				Just kept -> alterMinT pf (sizeT s) (\ _ _ -> Just kept) outer
		return ((k1 :*: k2, val), PMap remaining)

	-- Mirror image of 'extractMinT', working on the maximum entry.
	extractMaxT pf s (PMap outer) = do
		((k1, inner), outerRest) <- extractMaxT pf (sizeT s) outer
		((k2, val), innerRest) <- extractMaxT pf s inner
		let remaining = case guardNullT pf innerRest of
				Nothing -> outerRest
				Just kept -> alterMaxT pf (sizeT s) (\ _ _ -> Just kept) outer
		return ((k1 :*: k2, val), PMap remaining)

	alterMinT pf s f (PMap outer) = PMap (alterMinT pf (sizeT s) onInner outer)
		where
			onInner k1 = guardNullT pf . alterMinT pf s (\ k2 -> f (k1 :*: k2))

	alterMaxT pf s f (PMap outer) = PMap (alterMaxT pf (sizeT s) onInner outer)
		where
			onInner k1 = guardNullT pf . alterMaxT pf s (\ k2 -> f (k1 :*: k2))

	-- Inner maps are compared with the same submap test, one level down.
	isSubmapT pf le (PMap lhs) (PMap rhs) = isSubmapT pf (isSubmapT pf le) lhs rhs

	-- 'breakFst' only groups adjacent entries, so the groups are first
	-- collected into an outer map of lists (groups with the same first
	-- component are concatenated; each list counts as size 1 here), and
	-- each list is then turned into an inner map.
	fromListT pf s f assocs = PMap (mapWithKeyT pf (sizeT s)
			(\ k1 -> fromListT pf s (\ k2 -> f (k1 :*: k2)) . unK0)
			(fromListT pf (const 1) (\ _ (K0 ls) (K0 rs) -> K0 (ls ++ rs))
				[(k1, K0 grp) | (k1, grp) <- breakFst pf assocs]))

	-- Each group produced by 'breakFst' becomes one entry of the outer map.
	fromAscListT pf s f assocs = PMap (fromDistAscListT pf (sizeT s)
		[(k1, fromAscListT pf s (\ k2 -> f (k1 :*: k2)) grp) | (k1, grp) <- breakFst pf assocs])

	fromDistAscListT pf s assocs = PMap (fromDistAscListT pf (sizeT s)
		[(k1, fromDistAscListT pf s grp) | (k1, grp) <- breakFst pf assocs])

-- | Split an association list keyed by products into runs that share a first
-- component.  Only /adjacent/ entries are grouped: a new group starts whenever
-- 'heqT' reports that the next first component differs from the current one,
-- so equal first components that are not adjacent end up in separate groups.
-- Order is preserved both between groups and within each group.
breakFst :: (HEq phi f, HEq0 phi r) => phi ix -> [((f :*: g) r ix, a ix)] -> [(f r ix, [(g r ix, a ix)])]
breakFst _ [] = []
breakFst pf ((k1 :*: k2, v) : rest) = go k1 (Seq.singleton (k2, v)) rest
	where
		-- @cur@ is the first component of the group being built and @acc@
		-- holds that group's entries so far, appended on the right.
		go cur acc [] = [(cur, toList acc)]
		go cur acc ((nextK1 :*: nextK2, nextV) : more)
			| heqT pf cur nextK1	= go cur (acc |> (nextK2, nextV)) more
			| otherwise	= (cur, toList acc) : go nextK1 (Seq.singleton (nextK2, nextV)) more

-- | 'HTrieKey' instance for an applied product @(f :*: g) r@.  Every method
-- is forwarded unchanged to the like-named 'HTrieKeyT' method implemented for
-- 'ProdMap'; no additional logic lives here.
instance (HTrieKeyT phi f m1, m1 ~ HTrieMapT phi f, HTrieKeyT phi g m2, m2 ~ HTrieMapT phi g,
		HTrieKey phi r (HTrieMap phi r)) => HTrieKey phi ((f :*: g) r) (ProdMap phi m1 m2 r) where
	-- Construction, emptiness and size.
	emptyH = emptyT
	nullH = nullT
	sizeH = sizeT
	-- Lookup by key and by index.
	lookupH = lookupT
	lookupIxH = lookupIxT
	assocAtH = assocAtT
	-- Point updates.
	updateAtH = updateAtT
	alterH = alterT
	-- Traversals and folds.
	traverseWithKeyH = traverseWithKeyT
	foldWithKeyH = foldWithKeyT
	foldlWithKeyH = foldlWithKeyT
	-- Partitioning and splitting.
	mapEitherH = mapEitherT
	splitLookupH = splitLookupT
	-- Set-like combination of two maps.
	unionH = unionT
	isectH = isectT
	diffH = diffT
	-- Operations on the minimum / maximum entry.
	alterMinH = alterMinT
	alterMaxH = alterMaxT
	extractMinH = extractMinT
	extractMaxH = extractMaxT
	isSubmapH = isSubmapT
	-- Bulk construction from association lists.
	fromListH = fromListT
	fromAscListH = fromAscListT
	fromDistAscListH = fromDistAscListT