{-# LANGUAGE KindSignatures, TypeFamilies, MultiParamTypeClasses, FlexibleContexts, FlexibleInstances, UndecidableInstances #-}

module Data.TrieMap.MultiRec.ConstMap where

import Data.TrieMap.MultiRec.Class
import Data.TrieMap.MultiRec.Eq
import Data.TrieMap.MultiRec.Sized
import Data.TrieMap.Applicative
import Data.TrieMap.TrieKey

import Control.Applicative
import Control.Arrow
import Control.Monad

import Data.Maybe
import Data.Foldable
import Generics.MultiRec

-- | Newtype wrapper around a map value of type @m a ix@.
--
-- The @phi@ and @r@ parameters do not occur in the wrapped type; they are
-- phantom parameters that only exist at the type level.
newtype KMap (phi :: * -> *) m (r :: * -> *) (a :: * -> *) ix = KMap (m a ix)
-- For keys built with K k, the HTrieMapT family resolves to KMap wrapped around
-- the TrieMap k family instance.
type instance HTrieMapT phi (K k) = KMap phi (TrieMap k)
-- The HTrieMap family for the fully applied K k r is defined through HTrieMapT,
-- passing the extra parameter r along.
type instance HTrieMap phi (K k r) = HTrieMapT phi (K k) r

-- HTrieKeyT instance for K k: each T-suffixed method is a direct alias for the
-- H-suffixed method of the same name, so no logic lives here.
-- NOTE(review): the H methods presumably resolve to the HTrieKey instance for
-- K k r defined in this module -- confirm against the class declarations.
instance TrieKey k m => HTrieKeyT phi (K k) (KMap phi m) where
	-- empty / null / size
	emptyT = emptyH
	nullT = nullH
	sizeT = sizeH

	-- lookup*, assocAt, updateAt, alter
	lookupT = lookupH
	lookupIxT = lookupIxH
	assocAtT = assocAtH
	updateAtT = updateAtH
	alterT = alterH

	-- traverse / fold / mapEither
	traverseWithKeyT = traverseWithKeyH
	foldWithKeyT = foldWithKeyH
	foldlWithKeyT = foldlWithKeyH
	mapEitherT = mapEitherH

	-- operations taking two maps, plus splitLookup
	unionT = unionH
	isectT = isectH
	diffT = diffH
	isSubmapT = isSubmapH
	splitLookupT = splitLookupH

	-- *Min / *Max variants
	extractMinT = extractMinH
	alterMinT = alterMinH
	extractMaxT = extractMaxH
	alterMaxT = alterMaxH

	-- from*List builders
	fromListT = fromListH
	fromAscListT = fromAscListH
	fromDistAscListT = fromDistAscListH

-- HTrieKey instance for keys wrapped in the K constructor.
-- Every method strips the KMap wrapper (and K from any key argument), calls the
-- M-suffixed counterpart on the wrapped map, and re-applies KMap / K to the
-- result where the result contains a map or a key.
-- Leading arguments matched with _ are not used by this instance.
instance TrieKey k m => HTrieKey phi (K k r) (KMap phi m r) where
	emptyH _ = KMap emptyM
	nullH _ (KMap t) = nullM t
	sizeH sz (KMap t) = sizeM sz t
	lookupH _ (K key) (KMap t) = lookupM key t
	lookupIxH _ sz (K key) (KMap t) = lookupIxM sz key t
	-- Only the key component of the returned triple is re-wrapped.
	assocAtH _ sz pos (KMap t) = case assocAtM sz pos t of
		(pos', key, val) -> (pos', K key, val)
	updateAtH _ sz f pos (KMap t) = KMap $ updateAtM sz (\ j -> f j . K) pos t
	alterH _ sz f (K key) (KMap t) = KMap $ alterM sz f key t
	traverseWithKeyH _ sz f (KMap t) = fmap KMap (traverseWithKeyM sz (f . K) t)
	foldWithKeyH _ f (KMap t) = foldWithKeyM (f . K) t
	foldlWithKeyH _ f (KMap t) = foldlWithKeyM (f . K) t
	-- Both halves of the resulting pair are re-wrapped.
	mapEitherH _ sz1 sz2 f (KMap t) = KMap *** KMap $ mapEitherM sz1 sz2 (f . K) t
	splitLookupH _ sz f (K key) (KMap t) = sides KMap (splitLookupM sz f key t)
	unionH _ sz f (KMap t1) (KMap t2) = KMap $ unionM sz (f . K) t1 t2
	isectH _ sz f (KMap t1) (KMap t2) = KMap $ isectM sz (f . K) t1 t2
	diffH _ sz f (KMap t1) (KMap t2) = KMap $ diffM sz (f . K) t1 t2
	-- Explicit bind with a lambda pattern: the same shape a do-block with a
	-- tuple pattern desugars to.
	extractMinH _ sz (KMap t) =
		extractMinM sz t >>= \ ((key, val), rest) -> return ((K key, val), KMap rest)
	extractMaxH _ sz (KMap t) =
		extractMaxM sz t >>= \ ((key, val), rest) -> return ((K key, val), KMap rest)
	alterMinH _ sz f (KMap t) = KMap $ alterMinM sz (f . K) t
	alterMaxH _ sz f (KMap t) = KMap $ alterMaxM sz (f . K) t
	isSubmapH _ le (KMap t1) (KMap t2) = isSubmapM le t1 t2
	-- The comprehensions strip K from each key before handing the list on;
	-- the combining function gets K re-applied via (f . K).
	fromListH _ sz f xs = KMap $ fromListM sz (f . K) [(key, val) | (K key, val) <- xs]
	fromAscListH _ sz f xs = KMap $ fromAscListM sz (f . K) [(key, val) | (K key, val) <- xs]
	fromDistAscListH _ sz xs = KMap $ fromDistAscListM sz [(key, val) | (K key, val) <- xs]