{-# LANGUAGE TypeOperators, FlexibleContexts, StandaloneDeriving #-}

module TrieMap.MapTypes where

import Data.Foldable
import Data.Traversable
import Control.Applicative
import Prelude hiding (foldl, foldr)

-- | 'ProdMap' holds a map on the product of two key types. It is stored as
-- an outer @m1@ container whose elements are inner @m2@ containers that
-- hold the values.
newtype ProdMap m1 m2 v = PMap
	{ unPMap :: m1 (m2 v)
	}
	deriving (Eq, Ord)

-- | 'UnionMap' holds a map on the sum of two key types. It pairs one
-- component map per summand: an @m1@ map and an @m2@ map, joined by the
-- infix constructor ':+:'.
data UnionMap m1 m2 v
	= m1 v :+: m2 v
	deriving (Eq, Ord)

-- | One edge of a radix trie. It carries a list of @k@ (presumably the key
-- segment labelling the edge), an optional value, and an @m@-container of
-- child edges.
data Edge k m v
	= Edge [k] (Maybe v) (m (Edge k m v))

-- | An edge that may be absent. 'Nothing' presumably stands for an empty
-- trie (this is what 'RadixTrie' wraps).
type MEdge k m v = Maybe (Edge k m v)

-- | 'RadixTrie' holds a map keyed by lists of @k@. It is stored as an
-- optional root 'Edge' (an 'MEdge'); child edges live in an @m@ container.
newtype RadixTrie k m v = Radix
	{ unRad :: MEdge k m v
	}

-- Both the type constructor 'ProdMap' (when written infix with backticks)
-- and the 'UnionMap' data constructor ':+:' are right-associative at
-- precedence 5.
infixr 5 `ProdMap`, :+:

instance (Functor m1, Functor m2) => Functor (ProdMap m1 m2) where
	-- Unwrap, map through both container layers, then rewrap.
	fmap f = PMap . fmap (fmap f) . unPMap

instance (Foldable m1, Foldable m2) => Foldable (ProdMap m1 m2) where
	-- Visit the inner m2 containers in the outer m1's order, folding the
	-- values of each inner container in turn.
	foldr f z (PMap outer) = foldr step z outer
		where step inner acc = foldr f acc inner
	foldl f z (PMap outer) = foldl step z outer
		where step acc inner = foldl f acc inner

instance (Traversable m1, Traversable m2) => Traversable (ProdMap m1 m2) where
	-- Traverse the outer layer, traversing each inner container in turn,
	-- then rewrap the result inside the applicative.
	traverse f = fmap PMap . traverse (traverse f) . unPMap

instance (Functor m1, Functor m2) => Functor (UnionMap m1 m2) where
	-- Map both component maps independently.
	fmap f u = case u of
		l :+: r -> (:+:) (fmap f l) (fmap f r)

instance (Foldable m1, Foldable m2) => Foldable (UnionMap m1 m2) where
	-- Elements of the left (m1) map come before those of the right (m2) map.
	foldr f z (l :+: r) = foldr f rightDone l
		where rightDone = foldr f z r
	foldl f z (l :+: r) = foldl f leftDone r
		where leftDone = foldl f z l

instance (Traversable m1, Traversable m2) => Traversable (UnionMap m1 m2) where
	-- Run the left map's effects, then the right map's, and recombine.
	traverse f (l :+: r) = liftA2 (:+:) left right
		where
			left  = traverse f l
			right = traverse f r

instance Functor m => Functor (Edge k m) where
	-- Keep the key list, map the optional value, and recursively map every
	-- child edge.
	fmap f (Edge ks v ts) = Edge ks v' ts'
		where
			v'  = fmap f v
			ts' = fmap (fmap f) ts

instance Functor m => Functor (RadixTrie k m) where
	-- Map through the optional root edge (and, via Edge's Functor, the
	-- whole trie).
	fmap f = Radix . fmap (fmap f) . unRad

instance Foldable m => Foldable (Edge k m) where
	-- Fold the value stored on this edge (if any) BEFORE the values in its
	-- child edges. This matches the order in which the Traversable instance
	-- for Edge visits them (value first, then children). Foldable and
	-- Traversable must agree (foldMap f == foldMapDefault f), otherwise
	-- toList / foldr would disagree with traverse / mapM.
	foldr f z (Edge _ v ts) = foldr f (foldr (flip (foldr f)) z ts) v
	foldl f z (Edge _ v ts) = foldl (foldl f) (foldl f z v) ts

instance Foldable m => Foldable (RadixTrie k m) where
	-- An absent root edge folds to the seed; otherwise fold the root edge.
	foldr f z (Radix me) = maybe z (foldr f z) me
	foldl f z (Radix me) = maybe z (foldl f z) me

instance Traversable m => Traversable (Edge k m) where
	-- Run the effect for this edge's own value first, then those of the
	-- child edges, and rebuild the edge with the same key list.
	traverse f (Edge ks v ts) = liftA2 (Edge ks) value children
		where
			value    = traverse f v
			children = traverse (traverse f) ts

instance Traversable m => Traversable (RadixTrie k m) where
	-- Traverse the optional root edge and rewrap the result.
	traverse f = fmap Radix . traverse (traverse f) . unRad