{-# LANGUAGE TemplateHaskell, PatternGuards, UndecidableInstances, FlexibleContexts, TypeOperators, TypeFamilies, MultiParamTypeClasses #-}

module Data.TrieMap.Regular.CompMap () where

import Data.TrieMap.Regular.Base
import Data.TrieMap.Regular.Class
import Data.TrieMap.Regular.Ord
import Data.TrieMap.Regular.Eq
-- import Data.TrieMap.Regular.TH
import Data.TrieMap.TrieKey
-- import Data.TrieMap.Rep
-- import Data.TrieMap.Rep.TH

import Control.Applicative
import Control.Arrow

import Prelude hiding (lookup)

-- | Trie for keys of the composed functor @f `O` g@ applied to @k@.
-- The wrapped trie @m@ is keyed by @App g k@, i.e. the inner @g k@
-- values are boxed in 'App' so they can act as the element type of the
-- outer functor's trie.
newtype CompMap m g k a = CMap (m (App g k) a)
-- | Newtype box around a functor application @f a@; 'unA' unwraps it.
newtype App f a = A {unA :: f a}
-- | Trie for keys of type @App f k@: wraps the functor-level trie @m@
-- instantiated at the element key type @k@.
newtype AppMap m k a = AMap (m k a)

-- Tries over 'App'-wrapped keys reuse the functor-level trie of @f@,
-- wrapped in 'AppMap'.
type instance TrieMapT (App f) = AppMap (TrieMapT f)
type instance TrieMap (App f r) = AppMap (TrieMapT f) r
-- Tries over composed functors @f `O` g@ use the trie of the outer
-- functor @f@, wrapped in 'CompMap' (whose inner keys are @App g k@).
type instance TrieMapT (f `O` g) = CompMap (TrieMapT f) g
type instance TrieMap ((f `O` g) r) = CompMap (TrieMapT f) g r

-- | Functor-level equality for 'App': unwrap both sides and defer to
-- 'eqT0' of the underlying functor @f@ with the same element equality.
instance EqT f => EqT (App f) where
	eqT0 eq x y = eqT0 eq (unA x) (unA y)

-- | Functor-level ordering for 'App': unwrap both sides and defer to
-- 'compareT0' of the underlying functor @f@ with the same element comparison.
instance OrdT f => OrdT (App f) where
	compareT0 cmp x y = compareT0 cmp (unA x) (unA y)

-- | Equality on wrapped values, implemented by 'eqT'.
instance (EqT f, Eq r) => Eq (App f r) where
	x == y = eqT x y

-- | Ordering on wrapped values, implemented by 'compareT'.
--
-- @r@ is the element type. @g@ is reserved in this module for the inner
-- functor of a composition.
instance (OrdT f, Ord r) => Ord (App f r) where
	compare = compareT

-- | Key-level trie interface for composite keys @(f `O` g) k@.
--
-- Every method forwards unchanged to the functor-level method of the same
-- name (suffix @T@ instead of @M@) from the 'TrieKeyT' instance for
-- @CompMap m g@. Methods are listed alphabetically.
instance ( TrieKeyT f m
         , Functor f
         , TrieKeyT g (TrieMapT g)
         , TrieKey k (TrieMap k)
         ) => TrieKey ((f `O` g) k) (CompMap m g k) where
	alterLookupM = alterLookupT
	alterM = alterT
	assocAtM = assocAtT
	diffM = diffT
	emptyM = emptyT
	extractM = extractT
	foldWithKeyM = foldWithKeyT
	foldlWithKeyM = foldlWithKeyT
	fromAscListM = fromAscListT
	fromDistAscListM = fromDistAscListT
	fromListM = fromListT
	isSubmapM = isSubmapT
	isectM = isectT
	lookupIxM = lookupIxT
	lookupM = lookupT
	mapEitherM = mapEitherT
	nullM = nullT
	splitLookupM = splitLookupT
	traverseWithKeyM = traverseWithKeyT
	unionM = unionT


-- | Functor-level trie for the composed key functor @f `O` g@.
--
-- Incoming keys @O x@ (with @x :: f (g k)@) are converted to
-- @fmap A x :: f (App g k)@ before being passed to the underlying
-- @m@-trie. Keys handed back to callbacks or results are converted back
-- with @O . fmap unA@.
instance ( TrieKeyT f m
         , Functor f
         , TrieKeyT g (TrieMapT g)
         ) => TrieKeyT (f `O` g) (CompMap m g) where
	emptyT = CMap emptyT
	nullT (CMap t) = nullT t
	sizeT s (CMap t) = sizeT s t
	lookupT (O x) (CMap t) = lookupT (fmap A x) t
	lookupIxT s (O x) (CMap t) = onKey (O . fmap unA) (lookupIxT s (fmap A x) t)
	assocAtT s i (CMap t) = onKey (O . fmap unA) (assocAtT s i t)
-- 	updateAtT s r f i (CMap m)
-- 		= CMap (updateAtT s r (\ i' -> f i' . O . fmap unA) i m)
	alterT s h (O x) (CMap t) = CMap (alterT s h (fmap A x) t)
	alterLookupT s h (O x) (CMap t) = fmap CMap (alterLookupT s h (fmap A x) t)
	traverseWithKeyT s h (CMap t) = fmap CMap (traverseWithKeyT s (h . O . fmap unA) t)
	foldWithKeyT h (CMap t) = foldWithKeyT (h . O . fmap unA) t
	foldlWithKeyT h (CMap t) = foldlWithKeyT (h . O . fmap unA) t
	mapEitherT s1 s2 h (CMap t) = (CMap *** CMap) (mapEitherT s1 s2 (h . O . fmap unA) t)
	splitLookupT s h (O x) (CMap t) = sides CMap (splitLookupT s h (fmap A x) t)
	isSubmapT le (CMap t1) (CMap t2) = isSubmapT le t1 t2
	extractT s h (CMap t) = fmap (fmap CMap) (extractT s (h . O . fmap unA) t)
-- 	extractMinT s f (CMap m) = second CMap <$> extractMinT s (f . O . fmap unA) m
-- 	extractMaxT s f (CMap m) = second CMap <$> extractMaxT s (f . O . fmap unA) m
-- 	alterMinT s f (CMap m) = CMap (alterMinT s (f . O . fmap unA) m)
-- 	alterMaxT s f (CMap m) = CMap (alterMaxT s (f . O . fmap unA) m)
	unionT s h (CMap t1) (CMap t2) = CMap (unionT s (h . O . fmap unA) t1 t2)
	isectT s h (CMap t1) (CMap t2) = CMap (isectT s (h . O . fmap unA) t1 t2)
	diffT s h (CMap t1) (CMap t2) = CMap (diffT s (h . O . fmap unA) t1 t2)

-- | Key-level trie interface for 'App'-wrapped keys @App f k@.
--
-- Every method forwards unchanged to the functor-level method of the same
-- name (suffix @T@ instead of @M@) from the 'TrieKeyT' instance for
-- @AppMap m@. Methods are listed alphabetically.
instance ( TrieKeyT f m
         , TrieKey k (TrieMap k)
         ) => TrieKey (App f k) (AppMap m k) where
	alterLookupM = alterLookupT
	alterM = alterT
	assocAtM = assocAtT
	diffM = diffT
	emptyM = emptyT
	extractM = extractT
	foldWithKeyM = foldWithKeyT
	foldlWithKeyM = foldlWithKeyT
	fromAscListM = fromAscListT
	fromDistAscListM = fromDistAscListT
	fromListM = fromListT
	isSubmapM = isSubmapT
	isectM = isectT
	lookupIxM = lookupIxT
	lookupM = lookupT
	mapEitherM = mapEitherT
	nullM = nullT
	splitLookupM = splitLookupT
	traverseWithKeyM = traverseWithKeyT
	unionM = unionT

-- | Functor-level trie for the 'App' wrapper: a thin layer over the trie
-- @m@ of the wrapped functor @f@.
--
-- Incoming keys are unwrapped with 'unA' before reaching @m@. Keys passed
-- back to callbacks or returned in results are re-wrapped with 'A'.
instance TrieKeyT f m => TrieKeyT (App f) (AppMap m) where
	emptyT = AMap emptyT
	nullT (AMap t) = nullT t
	sizeT s (AMap t) = sizeT s t
	lookupT k (AMap t) = lookupT (unA k) t
	lookupIxT s k (AMap t) = onKey A (lookupIxT s (unA k) t)
	assocAtT s i (AMap t) = onKey A (assocAtT s i t)
-- 	updateAtT s r f i (AMap m) = AMap (updateAtT s r (\ i' -> f i' . A) i m)
	alterT s h k (AMap t) = AMap (alterT s h (unA k) t)
	alterLookupT s h k (AMap t) = fmap AMap (alterLookupT s h (unA k) t)
	traverseWithKeyT s h (AMap t) = fmap AMap (traverseWithKeyT s (h . A) t)
	foldWithKeyT h (AMap t) = foldWithKeyT (h . A) t
	foldlWithKeyT h (AMap t) = foldlWithKeyT (h . A) t
	mapEitherT s1 s2 h (AMap t) = (AMap *** AMap) (mapEitherT s1 s2 (h . A) t)
	splitLookupT s h k (AMap t) = sides AMap (splitLookupT s h (unA k) t)
	extractT s h (AMap t) = fmap (fmap AMap) (extractT s (h . A) t)
-- 	extractMinT s f (AMap m) = second AMap <$> extractMinT s (f . A) m
-- 	extractMaxT s f (AMap m) = second AMap <$> extractMaxT s (f . A) m
-- 	alterMinT s f (AMap m) = AMap (alterMinT s (f . A) m)
-- 	alterMaxT s f (AMap m) = AMap (alterMaxT s (f . A) m)
	unionT s h (AMap t1) (AMap t2) = AMap (unionT s (h . A) t1 t2)
	isectT s h (AMap t1) (AMap t2) = AMap (isectT s (h . A) t1 t2)
	diffT s h (AMap t1) (AMap t2) = AMap (diffT s (h . A) t1 t2)
	isSubmapT le (AMap t1) (AMap t2) = isSubmapT le t1 t2