{-# LANGUAGE PatternGuards, Rank2Types, FlexibleContexts, MultiParamTypeClasses, FunctionalDependencies, TypeFamilies, KindSignatures #-}

module Data.TrieMap.TrieKey where

import Data.TrieMap.Applicative
import Data.TrieMap.Sized
import Data.TrieMap.CPair

import Control.Applicative
import Control.Arrow

import Data.Monoid
import Data.List

-- | Open type family that assigns to a key type @k@ a type constructor of
-- kind @* -> *@.  Every 'TrieKey' method carries the constraint
-- @TrieMap k ~ m@, so the class's @m@ parameter is always the constructor
-- selected here for @k@.
type family TrieMap k :: * -> *

-- type family MapPF (m :: (* -> *) -> * -> *) ix :: (* -> *) -> *
-- data Fixer f

-- | Callback accepted by 'mapEitherM': from a key and a value it produces two
-- optional results, one of each output value type.
type EitherMap k a b c = k -> a -> (Maybe b, Maybe c)
-- | Callback accepted by 'splitLookupM': turns a value into three optional
-- pieces.  NOTE(review): presumably (part kept on the left, result reported
-- to the caller, part kept on the right) -- confirm against the instances.
type SplitMap a x = a -> (Maybe a, Maybe x, Maybe a)
-- | Combining function accepted by 'unionM': receives the key and two values
-- of the same type and may return a merged value.
type UnionFunc k a = k -> a -> a -> Maybe a
-- | Combining function accepted by 'isectM': receives the key and one value
-- of each input type and may return a value of the output type.
type IsectFunc k a b c = k -> a -> b -> Maybe c
-- | Combining function accepted by 'diffM': receives the key, a value of the
-- first map's type and a value of the second map's type, and may return a
-- value of the first map's type.
type DiffFunc k a b = k -> a -> b -> Maybe a
-- | Shape of 'extractM' after its 'Sized' argument: a per-entry callback
-- returning, inside @f@, a result paired with an optional replacement value
-- is turned into a function on the container @m@ that returns, inside @f@, a
-- result paired with a container.
type ExtractFunc f m k a x = (k -> a -> f (CPair x (Maybe a))) -> m -> f (CPair x m)
-- | Binary predicate relating a value of type @a@ to a value of type @b@;
-- 'isSubmapM' maps such a predicate on values to one on maps.
type LEq a b = a -> b -> Bool

-- | A key and a value tagged with a strict, unpacked 'Int'.  'onIndexA'
-- treats that 'Int' as an index.
-- NOTE(review): presumably the position of the entry inside its map -- confirm.
data Assoc k a = Asc {-# UNPACK #-} !Int k a
-- data IndexPos k a = Between {-# UNPACK #-} !(Assoc k a) {-# UNPACK #-} !(Assoc k a)
-- 			| Exact {-# UNPACK #-} !(Assoc k a) (Last (Assoc k a)) (First (Assoc k a))
-- 			| Above {-# UNPACK #-} !(Assoc k a) | Below {-# UNPACK #-} !(Assoc k a) | Nada
-- | Result type of 'lookupIxM' and 'assocAtM': three optional associations
-- wrapped in 'Last', 'Maybe' and 'First' respectively.
-- NOTE(review): presumably (nearest entry before, exact match, nearest entry
-- after), as the commented-out constructors above suggest -- confirm.
type IndexPos k a = (Last (Assoc k a), Maybe (Assoc k a), First (Assoc k a))

-- | Rewrite the 'Int' tag of an association with the given function; the key
-- and the value are passed through unchanged.
onIndexA :: (Int -> Int) -> Assoc k a -> Assoc k a
onIndexA adjust assoc = case assoc of
  Asc ix key val -> Asc (adjust ix) key val

-- | Apply 'onIndexA' with the given function to each of the three optional
-- associations of an 'IndexPos'.
onIndex :: (Int -> Int) -> IndexPos k a -> IndexPos k a
onIndex f (before, exact, after) =
  (fmap shift before, fmap shift exact, fmap shift after)
  where
    -- Kept at the 'Assoc' level so it can be mapped under three different functors.
    shift = onIndexA f

-- | Transform the key of every association held in an 'IndexPos', leaving
-- the values alone.
onKey :: (k -> k') -> IndexPos k a -> IndexPos k' a
onKey f = onValue (first f)

-- | Transform the value of every association held in an 'IndexPos', leaving
-- the keys alone.
onVal :: (a -> a') -> IndexPos k a -> IndexPos k a'
onVal g = onValue (second g)

-- | Transform the key of a single association, leaving its value alone.
onKeyA :: (k -> k') -> Assoc k a -> Assoc k' a
onKeyA f = onValueA (first f)

-- | Transform the value of a single association, leaving its key alone.
onValA :: (a -> a') -> Assoc k a -> Assoc k a'
onValA g = onValueA (second g)

{-# INLINE onValueA #-}
-- | Run a function over the (key, value) pair of an association and rebuild
-- the association with the same 'Int' tag.  The function's result is only
-- demanded when the new key or the new value is.
onValueA :: ((k, a) -> (k', a')) -> Assoc k a -> Assoc k' a'
onValueA f (Asc ix key val) = Asc ix (fst mapped) (snd mapped)
  where
    mapped = f (key, val)

{-# INLINE onValue #-}
-- | Apply 'onValueA' with the given pair function to each of the three
-- optional associations of an 'IndexPos'.
onValue :: ((k, a) -> (k', a')) -> IndexPos k a -> IndexPos k' a'
onValue f (before, exact, after) =
  (fmap rewrite before, fmap rewrite exact, fmap rewrite after)
  where
    -- Kept at the 'Assoc' level so it can be mapped under three different functors.
    rewrite = onValueA f

-- | Alias for 'Bool'.  In this module it is mentioned only by the
-- commented-out @updateAtM@ signature of 'TrieKey'.
type Round = Bool
-- type Sized f = forall ix . f ix -> Int

-- toFixer :: a -> Fixer a
-- toFixer _ = undefined

-- | Class tying an ordered key type @k@ to the type constructor @m@ of its
-- maps.  The functional dependency @m -> k@ lets the map type determine the
-- key type, and every method also requires @TrieMap k ~ m@.
--
-- Arguments of type @Sized a@ are size functions on stored values: the
-- default 'sizeM' applies one to every value and adds up the results.
--
-- Only 'alterM', 'sizeM', 'fromListM', 'fromAscListM' and 'fromDistAscListM'
-- have default definitions; all other methods must come from the instance.
--
-- NOTE(review): the remarks on individual methods below are derived from
-- their names and types only -- confirm against the instances.
class Ord k => TrieKey k m | m -> k where
	-- A map built from no arguments; the default 'fromListM' starts from it.
	emptyM :: TrieMap k ~ m => m a
	-- Test on a map; 'guardNullM' maps a True answer to Nothing.
	nullM :: TrieMap k ~ m => m a -> Bool
	-- Total size of a map under the given value-size function.
	sizeM :: (TrieMap k ~ m) => Sized a -> m a -> Int
	-- Optional value stored under a key.
	lookupM :: TrieMap k ~ m => k -> m a -> Maybe (a)
	-- Index-position information for a key (see 'IndexPos').
	lookupIxM :: TrieMap k ~ m => Sized a -> k -> m a -> IndexPos k a
	-- Index-position information for an Int index (see 'IndexPos').
	assocAtM :: TrieMap k ~ m => Sized a -> Int -> m a -> IndexPos k a
-- 	updateAtM :: TrieMap k ~ m => Sized a -> Round -> (Int -> k -> a -> Maybe (a)) -> Int -> m a -> m a
	-- Rewrite the optional value at a key with the given function.
	alterM :: (TrieMap k ~ m) => Sized a -> (Maybe (a) -> Maybe (a)) -> k -> m a -> m a
	-- Like 'alterM', but the function also yields a side result of type x
	-- that is returned paired with the new map.
	alterLookupM :: TrieMap k ~ m => Sized a -> (Maybe a -> CPair x (Maybe a)) -> k -> m a -> CPair x (m a)
	-- NOTE(review): the type in this pragma has no 'Sized b' argument, whereas
	-- the method signature below does -- confirm the pragma is still valid.
	{-# SPECIALIZE traverseWithKeyM :: (k -> a -> Id (b)) -> m a -> Id (m b) #-}
	-- Applicative traversal of the entries with access to each key.
	traverseWithKeyM :: (TrieMap k ~ m, Applicative f) => Sized b ->
		(k -> a -> f (b)) -> m a -> f (m b)
	-- Folds over the entries; the two differ in the argument order of the
	-- step function (value before accumulator vs. accumulator before value).
	foldWithKeyM :: TrieMap k ~ m => (k -> a -> b -> b) -> m a -> b -> b
	foldlWithKeyM :: TrieMap k ~ m => (k -> b -> a -> b) -> m a -> b -> b
	-- Build two maps from one using an 'EitherMap' callback.
	mapEitherM :: (TrieMap k ~ m) => Sized b -> Sized c -> EitherMap k (a) (b) (c) -> m a -> (m b, m c)
	-- Split a map around a key, returning two maps and an optional result.
	splitLookupM :: (TrieMap k ~ m) => Sized a -> SplitMap (a) x -> k -> m a -> (m a, Maybe x, m a)
	-- Binary combinations of two maps, driven by a per-key combining function.
	unionM :: (TrieMap k ~ m) => Sized a -> UnionFunc k (a) -> m a -> m a -> m a
	isectM :: (TrieMap k ~ m) => Sized c -> IsectFunc k (a) (b) (c) -> m a -> m b -> m c
	diffM :: (TrieMap k ~ m) => Sized a -> DiffFunc k (a) (b) -> m a -> m b -> m a
	-- Run a per-entry callback in an 'Alternative' functor (see 'ExtractFunc').
	extractM :: (TrieMap k ~ m, Alternative f) => Sized a -> ExtractFunc f (m a) k a x
-- 	extractMinM :: (TrieMap k ~ m) => Sized a -> ExtractFunc k First (a) (m a) x
-- 	extractMaxM :: (TrieMap k ~ m) => Sized a -> ExtractFunc k Last (a) (m a) x
-- 	alterMinM :: (TrieMap k ~ m) => Sized a -> (k -> a -> Maybe a) -> m a -> First (m a)
-- 	alterMaxM :: (TrieMap k ~ m) => Sized a -> (k -> a -> Maybe a) -> m a -> Last (m a)
	-- Lift a predicate relating values to a predicate relating maps.
	isSubmapM :: TrieMap k ~ m => LEq (a) (b) -> LEq (m a) (m b)
	-- Build a map from a list of pairs, given a function for combining values.
	fromListM, fromAscListM :: (TrieMap k ~ m) => Sized a -> (k -> a -> a -> a) -> [(k, a)] -> m a
	-- Build a map from a list of pairs without a combining function.
	fromDistAscListM :: (TrieMap k ~ m) => Sized a -> [(k, a)] -> m a
	
-- 	alterLookupM s f k m = fmap (\ v' -> alterM s (const v') k m) (f (lookupM k m))
	-- Default: call 'alterLookupM' with () as the side result and keep only
	-- the map that cpSnd projects out of the returned pair.
	alterM s f k m = cpSnd (alterLookupM s (cP () . f) k m)
	-- Default: fold over every entry, adding the size s of each value to an
	-- accumulator that starts at 0.
	sizeM s m = foldWithKeyM (\ _ a n -> s a + n) m 0
	-- Default: strict left fold over the list that inserts each pair with
	-- 'insertWithKeyM' into a map starting from 'emptyM'.
	fromListM s f = foldl' (flip (uncurry (insertWithKeyM s f))) emptyM
	-- Default: no special handling -- identical to 'fromListM'.
	fromAscListM = fromListM
	-- Default: 'fromAscListM' with a combining function that ignores the key
	-- and returns the first of the two values it is given.
	fromDistAscListM s = fromAscListM s (const const)

-- | Wrap a map in 'Just', unless 'nullM' holds for it, in which case the
-- result is 'Nothing'.
guardNullM :: (TrieKey k m, m ~ TrieMap k) => m a -> Maybe (m a)
guardNullM m = if nullM m then Nothing else Just m

-- | Map a function over the first and third components of a triple; the
-- middle component is returned untouched.
sides :: (a -> c) -> (a, b, a) -> (c, b, c)
sides f triple = case triple of
  (lo, mid, hi) -> (f lo, mid, f hi)

-- | Map a partial function over the entries of a map.  Implemented through
-- 'mapEitherM' with a callback whose first component is always 'Nothing';
-- only the second of the two resulting maps is returned.
mapMaybeM :: (TrieKey k m, m ~ TrieMap k) => Sized b -> (k -> a -> Maybe (b)) -> m a -> m b
mapMaybeM s f = snd . mapEitherM elemSize s rightOnly
  where
    -- The annotation fixes the otherwise unconstrained value type of the
    -- discarded first map so that 'elemSize' applies to it.
    rightOnly = (,) (Nothing :: Maybe (Elem ix)) .: f

{-# INLINE [1] mapWithKeyM #-}
-- | Map a key-aware function over the values of a map by running
-- 'traverseWithKeyM' with results wrapped in 'Id' and unwrapping the
-- outcome with 'unId'.
mapWithKeyM :: (TrieKey k m, m ~ TrieMap k) => Sized b -> (k -> a -> b) -> m a -> m b
mapWithKeyM s f = unId . traverseWithKeyM s wrapped
  where
    wrapped = Id .: f

-- | Map a function over the values of a map; 'mapWithKeyM' with the key
-- ignored.
mapM :: (TrieKey k m, m ~ TrieMap k) => Sized b -> (a -> b) -> m a -> m b
mapM s g = mapWithKeyM s (const g)

-- | List the key/value pairs of a map: 'foldWithKeyM' conses each entry onto
-- the accumulated list, starting from the empty list.
assocsM :: (TrieKey k m, m ~ TrieMap k) => m a -> [(k, a)]
assocsM m = foldWithKeyM prepend m []
  where
    prepend key val rest = (key, val) : rest

-- | Insert a value at a key.  Delegates to 'insertWithKeyM' with a combining
-- function that discards the key and the stored value, so the inserted value
-- replaces whatever was there.
insertM :: (TrieKey k m, m ~ TrieMap k) => Sized a -> k -> a -> m a -> m a
insertM s = insertWithKeyM s keepInserted
  where
    keepInserted _ inserted _ = inserted

-- | Insert a value at a key through 'alterM'.  When no value is present the
-- new one is stored as is; otherwise the stored result is
-- @f key new old@.
insertWithKeyM :: (TrieKey k m, m ~ TrieMap k) => Sized a -> (k -> a -> a -> a) -> k -> a -> m a -> m a
insertWithKeyM s f k a = alterM s upsert k
  where
    -- 'Just' is produced before the previous value is inspected, so the
    -- result is always present and the combination stays lazy.
    upsert previous = Just $ case previous of
      Nothing -> a
      Just old -> f k a old

-- | Map obtained by inserting one key/value pair into 'emptyM' with
-- 'insertM'.
singletonM :: (TrieKey k m, m ~ TrieMap k) => Sized a -> k -> a -> m a
singletonM s key val = insertM s key val $ emptyM

-- | Build a map from a list of pairs via 'fromListM', passing a combining
-- function that ignores the key and returns the first of its two value
-- arguments.
fromListM' :: (TrieKey k m, m ~ TrieMap k) => Sized a -> [(k, a)] -> m a
fromListM' s = fromListM s pickFirst
  where
    pickFirst _ firstVal _ = firstVal

-- | Combine two optional values.  If either side is 'Nothing' the other side
-- is returned unchanged; if both are present the combining function decides
-- the result (and may itself return 'Nothing').
unionMaybe :: (a -> a -> Maybe a) -> Maybe a -> Maybe a -> Maybe a
unionMaybe f mx my = case mx of
  Nothing -> my
  Just x -> case my of
    Nothing -> mx
    Just y -> f x y

-- | Combine two optional values only when both are present; otherwise the
-- result is 'Nothing'.  The second argument is not inspected when the first
-- is 'Nothing'.
isectMaybe :: (a -> b -> Maybe c) -> Maybe a -> Maybe b -> Maybe c
isectMaybe f mx my = do
  x <- mx
  y <- my
  f x y

-- | Difference of two optional values.  A missing first value gives
-- 'Nothing'; a present first value is kept as is when the second is missing,
-- and otherwise the combining function decides what remains.
diffMaybe :: (a -> b -> Maybe a) -> Maybe a -> Maybe b -> Maybe a
diffMaybe f mx = case mx of
  Nothing -> \_ -> Nothing
  Just x -> \my -> case my of
    Nothing -> Just x
    Just y -> f x y

-- | Lift a binary predicate to optional values: a missing first value is
-- always accepted, two present values are compared with the predicate, and
-- a present first value with a missing second one is rejected.
subMaybe :: (a -> b -> Bool) -> Maybe a -> Maybe b -> Bool
subMaybe leq mx my = case (mx, my) of
  (Nothing, _) -> True
  (Just x, Just y) -> x `leq` y
  (Just _, Nothing) -> False

-- | Run a per-entry query in an 'Alternative' functor through 'extractM'.
-- Each callback result is paired with 'Nothing' as the replacement value,
-- the size function is constantly 0, and only the first component of the
-- pair returned by 'extractM' is kept.
aboutM :: (TrieKey k (TrieMap k), Alternative t) => (k -> a -> t z) -> TrieMap k a -> t z
aboutM f = cpFst <.> extractM (\_ -> 0) probe
  where
    probe key val = (\found -> cP found Nothing) <$> f key val

{-# RULES
-- 	"lookupM/emptyM" forall k . lookupM k emptyM = Nothing;
-- 	"sizeM/emptyM" forall s . sizeM s emptyM = 0;
-- 	"traverseWithKeyM/emptyM" forall s f . traverseWithKeyM s f emptyM = pure emptyM;
-- 	"extractM/emptyM" forall s f . extractM s f emptyM = empty;
-- 	"foldWithKeyM/emptyM" forall f . foldWithKeyM f emptyM z = z;
-- 	"foldlWithKeyM/emptyM" forall f . foldlWithKeyM f emptyM z = z;
-- 	"lookupIxM/emptyM" forall s k . lookupIxM s k emptyM = (empty, empty, empty);
-- 	"mapEitherM/emptyM" forall s1 s2 f . mapEitherM s1 s2 f emptyM = (emptyM, emptyM);
	#-}