{-# LANGUAGE TypeFamilies, MultiParamTypeClasses, Rank2Types, FlexibleInstances, FlexibleContexts, UndecidableInstances #-}

module Data.TrieMap.MultiRec.FamMap where

import Data.TrieMap.MultiRec.Class
import Data.TrieMap.MultiRec.Eq
import Data.TrieMap.MultiRec.Ord
import Data.TrieMap.MultiRec.Sized
import Data.TrieMap.Sized
import Data.TrieMap.Applicative
import Data.TrieMap.TrieKey

import Control.Applicative
import Control.Arrow

import Data.Maybe
import Data.Foldable
import Data.Sequence ((|>))
import qualified Data.Sequence as Seq

import Generics.MultiRec

-- | A value of family @phi@ at index @ix@, wrapped so that family-generic
-- instances (equality, ordering, tries) can be given for it.  The kind of
-- @phi@ is written out to match 'FamMap' and 'FMap'; without PolyKinds,
-- kind defaulting would otherwise assign the phantom @phi@ kind @*@, which
-- does not fit its uses as @Family phi@ and @El phi ix@.
newtype Family (phi :: * -> *) ix = F ix

-- | Trie whose keys live in @'Family' phi@: a wrapper around an underlying
-- trie @m@ applied to @Family phi@.
newtype FamMap (phi :: * -> *) m (a :: * -> *) ix = FamMap (m (Family phi) a ix)

-- Family keys are stored in the trie for the family's pattern functor.
type instance HTrieMap phi (Family phi) = FamMap phi (HTrieMapT phi (PF phi))

-- Two family values are compared for equality through their one-level
-- pattern-functor views produced by 'from''.
instance (Fam phi, HEq phi (PF phi), HFunctor phi (PF phi)) => HEq0 phi (Family phi) where
	heqH pf (F a) (F b) = heqT pf (toPF a) (toPF b)
		where toPF = from' pf

-- Family values are ordered by comparing their one-level pattern-functor
-- views produced by 'from''.
instance (Fam phi, HOrd phi (PF phi), HFunctor phi (PF phi)) => HOrd0 phi (Family phi) where
	compareH0 pf (F a) (F b) = hcompare pf (toPF a) (toPF b)
		where toPF = from' pf

-- Ordinary equality on family values, delegating to 'heqH' with the
-- family witness for @ix@ obtained through 'prove'.
instance (El phi ix, Fam phi, HEq phi (PF phi), HFunctor phi (PF phi)) => Eq (Family phi ix) where
	(==) lhs rhs = heqH (prove lhs) lhs rhs

-- Ordinary ordering on family values, delegating to 'compareH0' with the
-- family witness for @ix@ obtained through 'prove'.
instance (El phi ix, Fam phi, HOrd phi (PF phi), HFunctor phi (PF phi)) => Ord (Family phi ix) where
	compare lhs rhs = compareH0 (prove lhs) lhs rhs

-- | Obtain the family witness (@'proof'@) for the index of a 'Family' value.
-- The argument only fixes the types; its value is ignored.
prove :: El phi ix => Family phi ix -> phi ix
prove = const proof

-- | One step of the generic view of a family value: convert it with
-- multirec's 'from', then rewrap every recursive position from 'I0' into 'F'.
from' :: (Fam phi, HFunctor phi (PF phi)) => phi ix -> ix -> PF phi (Family phi) ix
from' pf x = hmap (\ _ (I0 y) -> F y) pf (from pf x)

-- | Reverse of 'from'': rewrap every recursive position from 'F' back into
-- 'I0' and rebuild the value with multirec's 'to'.
to' :: (Fam phi, HFunctor phi (PF phi)) => phi ix -> PF phi (Family phi) ix -> ix
to' pf = to pf . hmap unwrap pf
	where unwrap _ (F x) = I0 x

-- | Adapt a function on 'Family' keys so it accepts the pattern-functor form
-- of the key, rebuilding the key with 'to'' before calling it.
push :: (Fam phi, HFunctor phi (PF phi)) => phi ix -> (Family phi ix -> a) -> PF phi (Family phi) ix -> a
push pf f x = f (F (to' pf x))

-- Tries over family keys.  An incoming key @F k@ is converted with 'from''
-- and handed to the underlying pattern-functor trie @m@.  Keys coming back
-- out of @m@ are converted with 'to'': directly, or via 'push' when they are
-- passed to a user callback.
instance (Fam phi, HFunctor phi (PF phi), HTrieKeyT phi (PF phi) m) => HTrieKey phi (Family phi) (FamMap phi m) where
	emptyH pf = FamMap (emptyT pf)
	nullH pf (FamMap m) = nullT pf m
	sizeH s (FamMap m) = sizeT s m
	lookupH pf (F k) (FamMap m) = lookupT pf (from' pf k) m
	lookupIxH pf s (F k) (FamMap m) = lookupIxT pf s (from' pf k) m
	-- Only the key component needs converting; the first and last components
	-- pass through unchanged.  (i' avoids shadowing the argument i.)
	assocAtH pf s i (FamMap m) = case assocAtT pf s i m of
		(i', k, a) -> (i', F (to' pf k), a)
	updateAtH pf s f i (FamMap m) = FamMap (updateAtT pf s (\ j -> f j . F . to' pf) i m)
	alterH pf s f (F k) (FamMap m) = FamMap (alterT pf s f (from' pf k) m)
	traverseWithKeyH pf s f (FamMap m) =
		FamMap <$> traverseWithKeyT pf s (push pf f) m
	foldWithKeyH pf f (FamMap m) = foldWithKeyT pf (push pf f) m
	foldlWithKeyH pf f (FamMap m) = foldlWithKeyT pf (push pf f) m
	mapEitherH pf s1 s2 f (FamMap m) = (FamMap *** FamMap) (mapEitherT pf s1 s2 (push pf f) m)
	splitLookupH pf s f (F k) (FamMap m) = FamMap `sides` splitLookupT pf s f (from' pf k) m
	unionH pf s f (FamMap m1) (FamMap m2) = FamMap (unionT pf s (push pf f) m1 m2)
	isectH pf s f (FamMap m1) (FamMap m2) = FamMap (isectT pf s (push pf f) m1 m2)
	diffH pf s f (FamMap m1) (FamMap m2) = FamMap (diffT pf s (push pf f) m1 m2)
	extractMinH pf s (FamMap m) = do
		((k, a), m') <- extractMinT pf s m
		return ((F (to' pf k), a), FamMap m')
	extractMaxH pf s (FamMap m) = do
		((k, a), m') <- extractMaxT pf s m
		return ((F (to' pf k), a), FamMap m')
	alterMinH pf s f (FamMap m) = FamMap (alterMinT pf s (push pf f) m)
	alterMaxH pf s f (FamMap m) = FamMap (alterMaxT pf s (push pf f) m)
	-- The comparison is named le rather than (<=) so Prelude's operator is not shadowed.
	isSubmapH pf le (FamMap m1) (FamMap m2) = isSubmapT pf le m1 m2
	-- The list builders convert each input key to pattern-functor form first.
	fromListH pf s f xs = FamMap (fromListT pf s (push pf f) [(from' pf k, a) | (F k, a) <- xs])
	fromAscListH pf s f xs = FamMap (fromAscListT pf s (push pf f) [(from' pf k, a) | (F k, a) <- xs])
	fromDistAscListH pf s xs = FamMap (fromDistAscListT pf s [(from' pf k, a) | (F k, a) <- xs])

-- type family UniqueFam ix :: * -> *
-- | Adapter presenting a family trie @m@ at key index @xi@ as an ordinary
-- 'TrieKey' map.  Values are stored wrapped in multirec's 'I' as @I ix a@
-- (per multirec, an @I ix a xi@ holds an @a ix@).
newtype FMap (phi :: * -> *) m xi a ix = FMap (m (I ix a) xi)
-- Keys of type @Family phi ix@ use the family trie at key index @ix@.
type instance TrieMap (Family phi ix) = FMap phi (HTrieMap phi (Family phi)) ix

-- | Lift a size function on values to one on values wrapped in multirec's
-- 'I', by unwrapping with 'unI' first.
sizeI :: Sized a -> HSized phi (I ix a)
sizeI s = s . unI

-- Single-index 'TrieKey' interface for family keys.  Every method forwards to
-- the corresponding 'HTrieKey' method with the witness 'proof'.  Stored values
-- are wrapped in 'I' on the way in and unwrapped on the way out.  Size
-- functions are lifted with 'sizeI'.
instance (El phi ix, Fam phi, HFunctor phi (PF phi), HTrieKey phi (Family phi) m, m ~ HTrieMap phi (Family phi),
		HOrd phi (PF phi)) => TrieKey (Family phi ix) (FMap phi m ix) where
	emptyM = FMap (emptyH proof)
	nullM (FMap m) = nullH proof m
	sizeM s (FMap m) = sizeH (sizeI s) m
	lookupM k (FMap m) = unI <$> lookupH proof k m
	lookupIxM s k (FMap m) = fmap unI <$> lookupIxH proof (sizeI s) k m
	-- i' avoids shadowing the argument i; only the value needs unwrapping.
	assocAtM s i (FMap m) = case assocAtH proof (sizeI s) i m of
		(i', k, I a) -> (i', k, a)
	updateAtM s f i (FMap m) = FMap (updateAtH proof (sizeI s) (\ i' k (I a) -> I <$> f i' k a) i m)
	alterM s f k (FMap m) = FMap (alterH proof (sizeI s) (fmap I . f . fmap unI) k m)
	traverseWithKeyM s f (FMap m) = FMap <$> traverseWithKeyH proof (sizeI s) (\ k (I a) -> I <$> f k a) m
	foldWithKeyM f (FMap m) = foldWithKeyH proof (\ k (I a) -> f k a) m
	foldlWithKeyM f (FMap m) = foldlWithKeyH proof (\ k z (I a) -> f k z a) m
	mapEitherM s1 s2 f (FMap m) = 
		(FMap *** FMap) (mapEitherH proof (sizeI s1) (sizeI s2) (\ k (I a) -> (fmap I *** fmap I) (f k a)) m)
	splitLookupM s f k (FMap m) = FMap `sides` splitLookupH proof (sizeI s) (sides (I <$>) . f . unI) k m
	unionM s f (FMap m1) (FMap m2) = FMap (unionH proof (sizeI s) f' m1 m2) where
		f' k (I x) (I y) = I <$> f k x y
	isectM s f (FMap m1) (FMap m2) = FMap (isectH proof (sizeI s) f' m1 m2) where
		f' k (I x) (I y) = I <$> f k x y
	diffM s f (FMap m1) (FMap m2) = FMap (diffH proof (sizeI s) f' m1 m2) where
		f' k (I x) (I y) = I <$> f k x y
	extractMinM s (FMap m) = do
		((k, I a), m') <- extractMinH proof (sizeI s) m
		return ((k, a), FMap m')
	extractMaxM s (FMap m) = do
		((k, I a), m') <- extractMaxH proof (sizeI s) m
		return ((k, a), FMap m')
	alterMinM s f (FMap m) = FMap (alterMinH proof (sizeI s) (\ k (I a) -> I <$> f k a) m)
	alterMaxM s f (FMap m) = FMap (alterMaxH proof (sizeI s) (\ k (I a) -> I <$> f k a) m)
	-- The comparison is named le rather than (<=) so Prelude's operator is not
	-- shadowed; <<= lifts it to I-wrapped values.
	isSubmapM le (FMap m1) (FMap m2) = isSubmapH proof (<<=) m1 m2 where
		I a <<= I b = a `le` b
	fromListM s f xs = FMap (fromListH proof (sizeI s) (\ k (I a) (I b) -> I (f k a b)) [(k, I a) | (k, a) <- xs])
	fromAscListM s f xs = FMap (fromAscListH proof (sizeI s) (\ k (I a) (I b) -> I (f k a b)) [(k, I a) | (k, a) <- xs])
	fromDistAscListM s xs = FMap (fromDistAscListH proof (sizeI s) [(k, I a) | (k, a) <- xs])