module Data.Set.StringSet where
import Control.Monad
import Data.Binary
import Data.List (foldl')

-- | A set of strings stored as a ternary search tree. This structure is
-- commonly used for word lists, such as dictionaries for spell checking.
--
-- Each 'SNode' holds one character and three subtrees:
--
--   * strings whose current character compares lower than the node's,
--   * strings that continue after a character equal to the node's,
--   * strings whose current character compares higher than the node's.
--
-- The end of a stored word is marked by a node holding @'\\0'@.
data StringSet
    = SNode !Char !StringSet !StringSet !StringSet
    | SEnd
    deriving (Show, Eq)

-- | Insert a string into the set. If the string is already present, the
-- tree is returned unchanged.
--
-- The end of a word is marked by a node holding @'\\0'@. Because @'\\0'@
-- is the smallest 'Char', the end marker for a word always sits on the
-- leftmost path of the subtree reached after consuming that word.
--
-- NOTE(review): strings that themselves contain @'\\0'@ collide with the
-- end marker. For example, inserting @"\\0"@ into an empty set makes @""@
-- a member as well.
insert :: String -> StringSet -> StringSet
-- Empty subtree: build the remaining suffix (or just the end marker) as
-- a single chain.
insert key SEnd = insert' key
-- The word is fully consumed: find or create its end marker on the left
-- spine.
insert [] t@(SNode c lo eq hi)
    | c == '\0' = t                            -- already a member
    | otherwise = SNode c (insert [] lo) eq hi -- '\0' sorts before any other Char
-- General case: branch on the current character.
insert key@(k:ks) (SNode c lo eq hi) =
    case compare k c of
        LT -> SNode c (insert key lo) eq hi
        EQ -> SNode c lo (insert ks eq) hi
        GT -> SNode c lo eq (insert key hi)

-- | Build a tree holding exactly one word. Each character becomes a node
-- whose only non-empty child is the equal (middle) branch, and the chain
-- ends with a @'\\0'@ end-marker node. 'insert' uses this when it reaches
-- an empty subtree.
insert' :: String -> StringSet
insert' = foldr chain endMarker
  where
    chain c rest = SNode c SEnd rest SEnd
    endMarker    = SNode '\0' SEnd SEnd SEnd

-- | Test whether a string is stored in the set.
--
-- The search consumes the string one character at a time. Once the
-- whole string has been consumed, it is a member only if a @'\\0'@ end
-- marker is found on the leftmost path of the current subtree.
isElem :: String -> StringSet -> Bool
isElem = go
  where
    go _ SEnd = False
    go [] (SNode c lo _ _) = c == '\0' || go [] lo
    go s@(c:cs) (SNode pivot lo eq hi)
        | c < pivot = go s lo
        | c > pivot = go s hi
        | otherwise = go cs eq

-- | Count the nodes that hold a real character, i.e. every node except
-- the @'\\0'@ end markers.
treeSize :: StringSet -> Int
treeSize SEnd = 0
treeSize (SNode c lo eq hi) = self + treeSize lo + treeSize eq + treeSize hi
  where
    self
        | c == '\0' = 0
        | otherwise = 1

-- | Count the entries in the set, i.e. the number of @'\\0'@ end-marker
-- nodes in the tree.
--
-- NOTE(review): a stored string that itself contains @'\\0'@ contributes
-- extra marker nodes, so it is counted more than once.
numEntries :: StringSet -> Int
numEntries SEnd = 0
numEntries (SNode c lo eq hi) = sum (marker : map numEntries [lo, eq, hi])
  where
    marker = if c == '\0' then 1 else 0

-- | Build a set from a list of strings by inserting them left to right
-- into an empty tree. Duplicate strings are stored only once.
--
-- A strict left fold is used so that each intermediate tree is built
-- immediately, rather than piling up a chain of unevaluated 'insert'
-- calls proportional to the length of the input.
fromList :: [String] -> StringSet
fromList = foldl' (flip insert) SEnd


-- | Compact serialisation. Each node is written as a one-byte tag
-- followed by its payload:
--
--   * @0@: 'SEnd', no payload.
--   * @1@: a node whose three children are all 'SEnd'; only the character.
--   * @2@: a node whose low and high children are 'SEnd' (the shape
--     'insert'' produces); the character, then the equal child.
--   * @3@: any other node; the character, then the low, equal and high
--     children.
--
-- Decoding an unknown tag fails in the 'Get' monad rather than crashing
-- with a pattern-match error.
instance Binary StringSet where
    -- The order of these equations matters: the more specific shapes
    -- must be matched before the general case.
    put SEnd = putWord8 0
    -- Quite common, so specialised.
    put (SNode ch SEnd SEnd SEnd) = do
        putWord8 1
        put ch
    -- Also common: basically what insert' produces.
    put (SNode ch SEnd e SEnd) = do
        putWord8 2
        put ch
        put e
    -- General case.
    put (SNode ch l e h) = do
        putWord8 3
        put ch
        put l
        put e
        put h
    get = do
        tag <- getWord8
        case tag of
            0 -> return SEnd
            1 -> do
                ch <- get
                return (SNode ch SEnd SEnd SEnd)
            2 -> do
                ch <- get
                e <- get
                return (SNode ch SEnd e SEnd)
            3 -> liftM4 SNode get get get get
            _ -> fail ("Data.Set.StringSet.get: unknown tag " ++ show tag)