module Data.Map.TernaryMap where
import Control.Monad
import Data.Binary
import Data.List (foldl')


-- | The label stored at each node of a 'TernaryMap': either one element of a
-- key, or the marker showing that a key ends here together with its value.
data Elem2 a b
    = C !a   -- ^ One element of a key (strict).
    | Val b  -- ^ End-of-key marker carrying the value stored under the key.
    deriving (Show)
-- | A ternary search tree mapping keys of type @[a]@ to values of type @b@.
-- It is commonly used for storing word lists such as dictionaries.
--
-- NOTE: the derived 'Eq' instance compares labels with the 'Elem2' 'Eq'
-- instance, which treats any two 'Val' labels as equal, so two maps with the
-- same keys but different values compare equal.
data TernaryMap a b
    = TNode !(Elem2 a b) !(TernaryMap a b) !(TernaryMap a b) !(TernaryMap a b)
      -- ^ Label, then the low, equal (key continues) and high subtrees.
    | TEnd
      -- ^ The empty tree.
    deriving (Show, Eq)


-- | Two 'C' labels are equal when their elements are equal. Any two 'Val'
-- labels are equal: the values they carry are never compared, so no @Eq b@
-- constraint is needed. A 'C' never equals a 'Val'.
instance Eq a => Eq (Elem2 a b) where
    C p   == C q   = p == q
    Val _ == Val _ = True
    _     == _     = False

-- | The end-of-key marker 'Val' sorts below every 'C' label, and two 'Val'
-- labels compare equal regardless of their values. 'C' labels are ordered by
-- the 'Ord' instance of their elements.
instance (Ord a) => Ord (Elem2 a b) where
    compare (C p)   (C q)   = compare p q
    compare (Val _) (Val _) = EQ
    compare (Val _) (C _)   = LT
    compare (C _)   (Val _) = GT

-- | 'True' for the end-of-key marker 'Val', 'False' for an element label 'C'.
isVal :: Elem2 a b -> Bool
isVal (Val _) = True
isVal _       = False

-- | Build a tree holding a single key/value pair without descending through
-- an existing tree. Each key element becomes a 'C' node chained through the
-- middle subtree, and a 'Val' node holding the value terminates the chain.
-- Use it to create an initial tree, and use 'insert' thereafter.
insert' :: Ord a => [a] -> b -> TernaryMap a b
insert' key val = foldr link terminal key
  where
    -- The end-of-key marker has no children.
    terminal = TNode (Val val) TEnd TEnd TEnd
    -- Each element's only child is the node chain for the rest of the key.
    link c rest = TNode (C c) TEnd rest TEnd

-- | Insert a key/value pair into a tree. If the key is already present, its
-- value is replaced with the new one.
insert :: Ord a => [a] -> b -> TernaryMap a b -> TernaryMap a b
-- Key not yet exhausted: ordinary ternary-search descent. A 'Val' label
-- always compares below @C x@, so end-of-key markers are passed via the high
-- subtree.
insert xss@(x:xs) b (TNode ele l e h) =
    case compare (C x) ele of
        LT -> TNode ele (insert xss b l) e h
        EQ -> TNode ele l (insert xs b e) h
        GT -> TNode ele l e (insert xss b h)
-- Empty subtree: build the remaining chain of nodes directly.
insert xss@(_:_) b TEnd =
    insert' xss b
-- Key exhausted at an existing end-of-key marker: replace the stored value.
insert [] b (TNode (Val _) l e h) =
    TNode (Val b) l e h
-- Key exhausted at an element node: 'Val' sorts below every 'C', so the
-- marker belongs in the low subtree.
insert [] b (TNode ele@(C _) l e h) =
    TNode ele (insert [] b l) e h
-- Key exhausted at an empty subtree: a lone end-of-key marker.
insert [] b TEnd =
    TNode (Val b) TEnd TEnd TEnd


-- | 'True' exactly when the given list was inserted as a complete key. A
-- proper prefix of a stored key is not a key unless it was inserted itself.
isKey :: Ord a => [a] -> TernaryMap a b -> Bool
isKey key tree = case tree of
    TEnd -> False
    TNode label lo eq hi -> case key of
        -- Out of elements: look for an end-of-key marker here or below.
        [] -> isVal label || isKey [] lo
        k:ks -> case compare (C k) label of
            LT -> isKey key lo
            EQ -> isKey ks eq
            GT -> isKey key hi

-- | Look up the value stored under a key. Returns 'Nothing' when the key is
-- not in the tree.
getVal :: Ord a => [a] -> TernaryMap a b -> Maybe b
getVal _ TEnd = Nothing
getVal key (TNode label lo eq hi) = case key of
    [] -> case label of
        Val v -> Just v
        -- End-of-key markers sort below elements, so keep going low.
        C _   -> getVal [] lo
    k:ks -> case compare (C k) label of
        LT -> getVal key lo
        EQ -> getVal ks eq
        GT -> getVal key hi

-- | Count the element ('C') nodes in the tree. End-of-key 'Val' markers are
-- not counted.
treeSize :: TernaryMap a b -> Int
treeSize TEnd = 0
treeSize (TNode label lo eq hi) =
    case label of
        Val _ -> below
        C _   -> 1 + below
  where
    below = treeSize lo + treeSize eq + treeSize hi

-- | Count the entries stored in the tree, i.e. the number of end-of-key
-- ('Val') nodes.
numEntries :: TernaryMap a b -> Int
numEntries TEnd = 0
numEntries (TNode label lo eq hi) =
    here label + sum (map numEntries [lo, eq, hi])
  where
    here (Val _) = 1
    here (C _)   = 0

-- | Build a tree from a list of (key, value) pairs. Pairs are inserted left
-- to right, so when a key occurs more than once the last value wins.
fromList :: Ord a => [([a],b)] -> TernaryMap a b
-- foldl' evaluates each intermediate tree as it goes. A lazy foldl would first
-- build a chain of one unevaluated 'insert' per pair before doing any work.
fromList = foldl' (\tree (key, val) -> insert key val tree) TEnd

-- | Encodes an 'Elem2' as a one-byte tag (0 for 'C', 1 for 'Val') followed by
-- its payload.
instance (Binary a, Binary b) => Binary (Elem2 a b) where
    put (C x)   = putWord8 0 >> put x
    put (Val b) = putWord8 1 >> put b
    get = do
        tag <- getWord8
        case tag of
            0 -> C <$> get
            1 -> Val <$> get
            -- Report corrupt input as a decoding failure (visible to
            -- decodeOrFail) instead of throwing a pattern-match exception.
            _ -> fail ("Data.Map.TernaryMap: invalid Elem2 tag " ++ show tag)

-- | Encodes each node as a one-byte tag followed by its fields. Besides the
-- empty tree (tag 0) and the general node (tag 3: label and all three
-- subtrees), two common shapes get compact encodings that omit their empty
-- subtrees: childless nodes (tag 1) and nodes with only a middle subtree
-- (tag 2), which is the shape 'insert'' builds.
instance (Binary a, Binary b) => Binary (TernaryMap a b) where
    put TEnd = putWord8 0
    -- Quite common, so specialised
    put (TNode ch TEnd TEnd TEnd) = do
        putWord8 1
        put ch
    -- Also common, basically what insert' produces.
    put (TNode ch TEnd e TEnd) = do
        putWord8 2
        put ch
        put e
    -- General case
    put (TNode ch l e h) = do
        putWord8 3
        put ch
        put l
        put e
        put h
    get = do
        tag <- getWord8
        case tag of
            0 -> return TEnd
            1 -> do
                ch <- get
                return (TNode ch TEnd TEnd TEnd)
            2 -> do
                ch <- get
                e <- get
                return (TNode ch TEnd e TEnd)
            3 -> liftM4 TNode get get get get
            -- Report corrupt input as a decoding failure (visible to
            -- decodeOrFail) instead of throwing a pattern-match exception.
            _ -> fail ("Data.Map.TernaryMap: invalid TernaryMap tag " ++ show tag)