{-# LANGUAGE UnboxedTuples, TupleSections, PatternGuards, TypeFamilies #-}

module Data.TrieMap.ProdMap () where

import Data.TrieMap.Sized
import Data.TrieMap.TrieKey
import Data.TrieMap.Applicative

import Control.Applicative

import Data.Foldable

import Data.Sequence ((|>))
import qualified Data.Sequence as Seq

-- | Tries keyed by pairs are curried: the trie for @(k1, k2)@ is an outer
-- trie keyed by @k1@ whose values are inner tries keyed by @k2@. A hole for
-- @(k1, k2)@ pairs a hole in the outer trie with a hole in an inner trie.
--
-- Whenever an inner trie is recomputed and stored back into the outer trie,
-- it is first passed through 'guardNullM'. NOTE(review): presumably this
-- turns an empty inner trie into "no entry", so the outer trie never keeps
-- empty inner tries. Confirm against the definition of 'guardNullM'.
instance (TrieKey k1, TrieKey k2) => TrieKey (k1, k2) where
	-- Outer trie on the first component; its values are inner tries on the second.
	newtype TrieMap (k1, k2) a = PMap (TrieMap k1 (TrieMap k2 a))
	-- The outer hole points at an inner trie; the inner hole points at a value.
	data Hole (k1, k2) a = PHole (Hole k1 (TrieMap k2 a)) (Hole k2 a)

	-- Construction, emptiness and size all go through the outer trie.
	-- NOTE(review): sizeM relies on the outer trie's size accounting for the
	-- sizes of its inner-trie values (cf. Data.TrieMap.Sized). Confirm there.
	emptyM = PMap emptyM
	singletonM (k1, k2) a = PMap (singletonM k1 (singletonM k2 a))
	nullM (PMap m) = nullM m
	sizeM (PMap m) = sizeM m
	-- Find the inner trie stored under k1, then look up k2 in it.
	lookupM (k1, k2) (PMap m) = lookupM k1 m >>= lookupM k2
	-- Traversals, folds and maps visit every inner trie. @(a,)@ rebuilds the
	-- full pair key from the outer key a before f is called.
	traverseWithKeyM f (PMap m) = PMap <$> traverseWithKeyM (\ a -> traverseWithKeyM (f . (a,))) m
	foldrWithKeyM f (PMap m) = foldrWithKeyM (\ a -> foldrWithKeyM (f . (a,))) m
	-- flip adapts the inner fold's argument order to the one the outer
	-- callback receives (presumably accumulator before inner trie).
	foldlWithKeyM f (PMap m) = foldlWithKeyM (\ a -> flip (foldlWithKeyM (f . (a,)))) m
	mapWithKeyM f (PMap m) = PMap (mapWithKeyM (\ a -> mapWithKeyM (f . (a,))) m)
	-- Filter each inner trie with f specialised to its outer key a. The
	-- filtered inner trie goes through guardNullM before it is kept.
	mapMaybeM f (PMap m) = PMap (mapMaybeM g m) where
		g a = guardNullM . mapMaybeM (f . (a,))
	-- Same as mapMaybeM, but splitting into two tries. NOTE(review): `both`
	-- presumably applies its first two arguments to the two halves produced
	-- by its third, here wrapping the outer halves in PMap and guarding the
	-- inner halves with guardNullM.
	mapEitherM f (PMap m) = both PMap PMap (mapEitherM g) m where
		g a m = both guardNullM guardNullM (mapEitherM (f . (a,))) m
	-- (<=) is the caller-supplied value comparison, bound as an operator-named
	-- parameter (shadowing Prelude's). The outer submap check compares the
	-- inner tries by an inner submap check using that comparison.
	isSubmapM (<=) (PMap m1) (PMap m2) = isSubmapM (isSubmapM (<=)) m1 m2
	-- For the combining operations, the outer operation's callback receives
	-- the outer key a and the two inner tries under a. It combines them with
	-- the same operation (f specialised to a) and passes the result to
	-- guardNullM.
	unionM f (PMap m1) (PMap m2) = PMap (unionM (\ a -> guardNullM .: unionM (f . (a,))) m1 m2)
	isectM f (PMap m1) (PMap m2) = PMap (isectM (\ a -> guardNullM .: isectM (f . (a,))) m1 m2)
	diffM f (PMap m1) (PMap m2) = PMap (diffM (\ a -> guardNullM .: diffM (f . (a,))) m1 m2)
	-- breakFst groups runs of adjacent entries sharing a first component. The
	-- outer fromListM merges groups with the same first key by concatenating
	-- their entry lists. Each merged group is then built into an inner trie
	-- with f specialised to its first key.
	-- NOTE(review): the concatenation order (and thus the order in which f
	-- sees duplicate pair keys) depends on which argument the outer fromListM
	-- passes first to its combining function. Confirm.
	fromListM f xs = PMap (mapWithKeyM (\ a (Elem xs) -> fromListM (f . (a,)) xs)
		(fromListM (\ _ (Elem xs) (Elem ys) -> Elem (xs ++ ys)) (breakFst xs)))
	-- Assumes ascending input keeps equal first components adjacent. Then
	-- breakFst yields each first key once and in order, which is presumably
	-- what fromDistAscListM requires of the outer list.
	fromAscListM f xs = PMap (fromDistAscListM
		[(a, fromAscListM (f . (a,)) ys) | (a, Elem ys) <- breakFst xs])

	-- A hole at (k1, k2) is built from, and decomposes into, component holes.
	singleHoleM (k1, k2) = PHole (singleHoleM k1) (singleHoleM k2)
	keyM (PHole hole1 hole2) = (keyM hole1, keyM hole2)
	-- Store v at the inner hole, then place the resulting inner trie at the
	-- outer hole.
	assignM v (PHole hole1 hole2) = PMap (assignM (assignM v hole2) hole1)
	-- Clear the inner hole. The remaining inner trie goes through guardNullM
	-- before being filled back into the outer hole.
	clearM (PHole hole1 hole2) = PMap (fillHoleM (guardNullM (clearM hole2)) hole1)
	-- beforeM/afterM: the optional value a is handed to the inner operation at
	-- hole2. The resulting inner trie, passed through guardNullM, is handed to
	-- the outer operation at hole1 in the same role.
	beforeM a (PHole hole1 hole2) 
		= PMap (beforeM (guardNullM (beforeM a hole2)) hole1)
	afterM a (PHole hole1 hole2)
		= PMap (afterM (guardNullM (afterM a hole2)) hole1)
	-- If k1 has no inner trie, nothing is found and the inner hole is a fresh
	-- singleHoleM k2. Otherwise search the inner trie for k2; onUnboxed
	-- presumably wraps the inner search's hole with PHole hole1.
	searchM (k1, k2) (PMap m) = case searchM k1 m of
		(# Nothing, hole1 #)	-> (# Nothing, PHole hole1 (singleHoleM k2) #)
		(# Just m', hole1 #)	-> onUnboxed (PHole hole1) (searchM k2) m'
	-- Index the outer trie to get an inner trie m' and a residual index i'
	-- (presumably the offset of i within m'), then index m' with i'. Both
	-- guard patterns are unboxed tuples of plain variables, so they always
	-- match.
	indexM i (PMap m)
		| (# i', m', hole1 #) <- indexM i m,
		  (# i'', v, hole2 #) <- indexM i' m'
		  = (# i'', v, PHole hole1 hole2 #)
	-- Extract a hole from the outer trie, then a value and hole from the inner
	-- trie found there, and combine the two holes into a PHole.
	extractHoleM (PMap m) = do
		(m', hole1) <- extractHoleM m
		(v, hole2) <- extractHoleM m'
		return (v, PHole hole1 hole2)

-- | Split a list of pair-keyed entries into runs of adjacent entries whose
-- first key components compare equal. Each run is reported as a first key
-- paired with the run's (second key, value) entries in input order, wrapped
-- in 'Elem'.
--
-- Only adjacent entries are grouped: an equal first key that reappears after
-- a different one starts a new run. Each entry's first key is compared with
-- that of the entry just before it. The key reported for a run is the first
-- key of the run's last entry.
breakFst :: Eq k1 => [((k1, k2), a)] -> [(k1, Elem [(k2, a)])]
breakFst [] = []
breakFst (((k0, x0), v0) : rest0) = collect k0 (Seq.singleton (x0, v0)) rest0
  where
    -- Close the current run: its key plus its entries as a list.
    emit key run = (key, Elem (toList run))

    -- @key@ is the first key of the previous entry. @run@ holds the current
    -- run's entries, appended on the right to preserve input order.
    collect key run rest = case rest of
      [] -> [emit key run]
      ((key', x), v) : rest'
        | key == key' -> collect key' (run |> (x, v)) rest'
        | otherwise -> emit key run : collect key' (Seq.singleton (x, v)) rest'