{-# LANGUAGE TypeFamilies, MultiParamTypeClasses, FlexibleContexts,  TypeOperators, UndecidableInstances #-}

module Data.TrieMap.Regular.ProdMap() where

import Data.TrieMap.Regular.Class
import Data.TrieMap.Regular.Base
import Data.TrieMap.TrieKey
import Data.TrieMap.Applicative

import Control.Applicative
import Control.Arrow

import Data.Maybe

-- | Trie for product keys @a :*: b@, stored as two levels: an outer map
-- (@m1@) keyed by the first component whose values are inner maps
-- (@m2 k a@) keyed by the second component.
newtype PMap m1 (m2 :: * -> (* -> *) -> * -> *) k (a :: * -> *) ix = PMap (m1 k (m2 k a) ix)
-- The trie for a product functor is a 'PMap' built from the tries of its two
-- factors; the trie for a concrete product key @(f :*: g) r@ reuses it.
type instance TrieMapT (f :*: g) = PMap (TrieMapT f) (TrieMapT g)
type instance TrieMap ((f :*: g) r) = TrieMapT (f :*: g) r

-- A 'PMap' has the same generic pattern functor as the map it wraps.
type instance PF (PMap m1 m2 k a ix) = PF (m1 k (m2 k a) ix)

-- Generic view of a 'PMap': reuse the representation of the wrapped map,
-- adding or stripping the 'PMap' constructor both around the whole value and
-- around every recursive position (via the pattern functor's 'fmap').
instance (Regular (m1 k (m2 k a) ix), Functor (PF (m1 k (m2 k a) ix))) => Regular (PMap m1 m2 k a ix) where
	from pmap = case pmap of
		PMap inner -> PMap <$> from inner
	to rep = PMap (to (fmap unwrap rep)) where
		unwrap (PMap inner) = inner

-- Product keys @a :*: b@ are stored as a two-level trie: the outer map (for
-- @f@) is keyed by the first component @a@ and holds inner maps (for @g@)
-- keyed by the second component @b@. Every operation works on the outer map,
-- measuring each inner map with @sizeT s@, and delegates to the same
-- operation on the inner map.
--
-- Rebuilt inner maps are passed through 'guardNullT'. Presumably this yields
-- 'Nothing' for an empty inner map, so the outer map never keeps a key with
-- no entries. NOTE(review): confirm against the definition of guardNullT.
instance (TrieKeyT f m1, TrieKeyT g m2) => TrieKeyT (f :*: g) (PMap m1 m2) where
	emptyT = PMap emptyT
	nullT (PMap m) = nullT m
	-- Each inner map counts with its own size (measured by @sizeT s@).
	sizeT s (PMap m) = sizeT (sizeT s) m
	-- Look up the first component in the outer map, then the second
	-- component in the inner map found there.
	lookupT (a :*: b) (PMap m) = lookupT a m >>= lookupT b
	-- Index of an entry = index reported for its inner map in the outer map
	-- plus its index inside that inner map.
	lookupIxT s (a :*: b) (PMap m) = do
		(iA, m') <- lookupIxT (sizeT s) a m
		(iB, v) <- lookupIxT s b m'
		return (iA + iB, v)
	-- Select the inner map for position @i@ in the outer map, then the entry
	-- at the remaining offset @i - iA@ inside that inner map.
	assocAtT s i (PMap m) = case assocAtT (sizeT s) i m of
		(iA, a, m') -> case assocAtT s (i - iA) m' of
			(iB, b, v) -> (iA + iB, a :*: b, v)
	-- Same two-step index split as 'assocAtT'. The callback receives the
	-- combined index and the combined key.
	updateAtT s f i (PMap m) = PMap (updateAtT (sizeT s) g i m) where
		g iA a = guardNullT . updateAtT s (\ iB b -> f (iA + iB) (a :*: b)) (i - iA)
	-- A missing inner map is treated as empty before it is altered.
	alterT s f (a :*: b) (PMap m) = PMap (alterT (sizeT s) g a m) where
		g = guardNullT . alterT s f b . fromMaybe emptyT
	traverseWithKeyT s f (PMap m) = PMap <$> traverseWithKeyT (sizeT s) g m where
		g a = traverseWithKeyT s (\ b -> f (a :*: b))
	foldWithKeyT f (PMap m) = foldWithKeyT g m where
		g a = foldWithKeyT (\ b -> f (a :*: b))
	foldlWithKeyT f (PMap m) = foldlWithKeyT g m where
		-- The inner map is named @inner@ so it does not shadow the outer @m@.
		g a z inner = foldlWithKeyT (\ b -> f (a :*: b)) inner z
	mapEitherT s1 s2 f (PMap m) = (PMap *** PMap) (mapEitherT (sizeT s1) (sizeT s2) g m) where
		g a = (guardNullT *** guardNullT) . mapEitherT s1 s2 (\ b -> f (a :*: b))
	splitLookupT s f (a :*: b) (PMap m) = PMap `sides` splitLookupT (sizeT s) g a m where
		g = sides guardNullT . splitLookupT s f b
	-- Binary operations: the outer operation's per-key callback applies the
	-- same operation to the two inner maps and guards the result with
	-- 'guardNullT'.
	unionT s f (PMap m1) (PMap m2) = PMap (unionT (sizeT s) (\ a -> guardNullT .: unionT s (\ b -> f (a :*: b))) m1 m2)
	isectT s f (PMap m1) (PMap m2) = PMap (isectT (sizeT s) (\ a -> guardNullT .: isectT s (\ b -> f (a :*: b))) m1 m2)
	diffT s f (PMap m1) (PMap m2) = PMap (diffT (sizeT s) (\ a -> guardNullT .: diffT s (\ b -> f (a :*: b))) m1 m2)
	-- Extract the least outer key @a@ with its inner map @m1@, then the least
	-- entry of @m1@. If the rest of @m1@ (@m1'@) is non-empty, it replaces the
	-- minimal inner map of @m@. Otherwise the outer remainder @m'@ is used.
	extractMinT s (PMap m) = do
		((a, m1), m') <- extractMinT (sizeT s) m
		((b, v), m1') <- extractMinT s m1
		return ((a :*: b, v), PMap (maybe m' (\ _ -> alterMinT (sizeT s) (\ _ _ -> Just m1') m) (guardNullT m1')))
	-- Mirror image of 'extractMinT' for the greatest key.
	extractMaxT s (PMap m) = do
		((a, m1), m') <- extractMaxT (sizeT s) m
		((b, v), m1') <- extractMaxT s m1
		return ((a :*: b, v), PMap (maybe m' (\ _ -> alterMaxT (sizeT s) (\ _ _ -> Just m1') m) (guardNullT m1')))
	alterMinT s f (PMap m) = PMap (alterMinT (sizeT s) (\ a -> guardNullT . alterMinT s (\ b -> f (a :*: b))) m)
	alterMaxT s f (PMap m) = PMap (alterMaxT (sizeT s) (\ a -> guardNullT . alterMaxT s (\ b -> f (a :*: b))) m)
	-- The value comparison is named @le@ so it does not shadow Prelude's '<='.
	isSubmapT le (PMap m1) (PMap m2) = isSubmapT (isSubmapT le) m1 m2

-- A concrete product key @(f :*: g) k@ uses @PMap m1 m2 k@ as its map. Every
-- 'TrieKey' method (the @…M@ names) forwards to the 'TrieKeyT' method with the
-- corresponding @…T@ name.
instance (TrieKeyT f m1, TrieKeyT g m2, TrieKey k (TrieMap k)) => TrieKey ((f :*: g) k) (PMap m1 m2 k) where
	-- Construction.
	emptyM = emptyT
	fromListM = fromListT
	fromAscListM = fromAscListT
	fromDistAscListM = fromDistAscListT
	-- Queries.
	nullM = nullT
	sizeM = sizeT
	lookupM = lookupT
	lookupIxM = lookupIxT
	assocAtM = assocAtT
	isSubmapM = isSubmapT
	-- Point updates.
	updateAtM = updateAtT
	alterM = alterT
	alterMinM = alterMinT
	alterMaxM = alterMaxT
	-- Traversals and folds.
	traverseWithKeyM = traverseWithKeyT
	foldWithKeyM = foldWithKeyT
	foldlWithKeyM = foldlWithKeyT
	-- Splitting and extraction.
	mapEitherM = mapEitherT
	splitLookupM = splitLookupT
	extractMinM = extractMinT
	extractMaxM = extractMaxT
	-- Binary combinations.
	unionM = unionT
	isectM = isectT
	diffM = diffT