module Data.Set.StringSet where
import Control.Monad
import Data.Binary
import Data.Bits (testBit)
import Data.List (foldl')

-- | A ternary search tree of strings. It is commonly used to store word
-- lists such as dictionaries for spell checking.
--
-- An 'SNode' holds one character and three subtrees, in order:
-- characters that sort before it, the continuation of words that share
-- this character, and characters that sort after it. The end of a stored
-- word is marked by a node holding the NUL character. 'SEnd' is an empty
-- subtree.
data StringSet
    = SNode !Char !StringSet !StringSet !StringSet
    | SEnd
    deriving (Eq, Show)

-- | Inserts a string into the set. Re-inserting a string that is already
-- present yields an equal tree.
--
-- The end of a word is recorded as a NUL-character node. NUL is the
-- smallest 'Char', so that marker can only sit on the low-branch chain of
-- the current subtree. That is why the empty-string case only ever walks
-- left.
--
-- NOTE: NUL doubles as the end-of-word marker. A string containing NUL is
-- therefore not distinguished from its prefix up to the NUL: inserting
-- "a\0" also makes "a" a member.
insert :: String -> StringSet -> StringSet
-- Empty subtree: build the whole remaining path (including the end
-- marker) in one go.
insert str SEnd = insert' str
-- End of word in a non-empty subtree: keep an existing marker, otherwise
-- keep descending to the left, where a new marker belongs.
insert [] t@(SNode ele l e h)
    | ele == '\0' = t
    | otherwise   = SNode ele (insert [] l) e h
-- General case: ordinary ternary-tree descent.
insert xss@(x:xs) (SNode ele l e h) =
    case compare x ele of
        LT -> SNode ele (insert xss l) e h
        EQ -> SNode ele l (insert xs e) h
        GT -> SNode ele l e (insert xss h)

-- | Builds a fresh single-word tree. Each character becomes a node
-- chained through the middle ("equal") branch, and the chain ends in a
-- NUL end-of-word node. 'insert' uses this once it reaches an empty
-- subtree.
insert' :: String -> StringSet
insert' = foldr step endMarker
  where
    step c rest = SNode c SEnd rest SEnd
    endMarker   = SNode '\0' SEnd SEnd SEnd

-- | Returns 'True' if the string is stored in the set.
--
-- Walks the tree one character at a time. Once the string is exhausted,
-- it looks for the NUL end-of-word marker along the low branches.
isElem :: String -> StringSet -> Bool
isElem _ SEnd = False
isElem word (SNode c lo eq hi) =
    case word of
        [] -> c == '\0' || isElem [] lo
        (w : ws) ->
            case compare w c of
                LT -> isElem word lo
                EQ -> isElem ws eq
                GT -> isElem word hi

-- | Counts the character nodes in the tree, excluding the NUL end-of-word
-- marker nodes. An empty set has size 0.
treeSize :: StringSet -> Int
treeSize SEnd = 0
treeSize (SNode c lo eq hi) = own + sum (map treeSize [lo, eq, hi])
  where
    own | c == '\0' = 0
        | otherwise = 1

-- | Counts the NUL end-of-word marker nodes anywhere in the tree. This is
-- the number of distinct strings stored, because 'insert' never adds a
-- second marker for a string that is already present.
numEntries :: StringSet -> Int
numEntries SEnd = 0
numEntries (SNode c lo eq hi) =
    marker + numEntries lo + numEntries eq + numEntries hi
  where
    marker = if c == '\0' then 1 else 0

-- | Builds a set from a list of strings, inserting them from left to right.
--
-- This uses a strict left fold, so each intermediate tree is evaluated as
-- soon as it is built. A lazy 'foldl' would first build one suspended
-- 'insert' per word. For a large word list that costs memory and stack
-- proportional to the whole input before any insertion runs.
fromList :: [String] -> StringSet
fromList = foldl' (flip insert) SEnd

-- | Compact binary encoding.
--
-- Each node is written as:
--
-- 1. a one-byte tag,
-- 2. its character,
-- 3. only its non-empty children, in the order low, equal, high.
--
-- The tag is a 3-bit mask saying which children follow: bit 2 for the low
-- child, bit 1 for the equal child, bit 0 for the high child. A leaf is
-- therefore tag 0 and a node with all three children is tag 7. Children
-- not flagged in the tag decode as 'SEnd'.
--
-- Tag 8 encodes a bare 'SEnd'. It is only written when 'put' is applied
-- directly to an empty set, since empty children are never written.
-- Decoding fails cleanly on any other tag value.
instance Binary StringSet where
    put SEnd = putWord8 8
    put (SNode ch l e h) = do
        putWord8 (presence 4 l + presence 2 e + presence 1 h)
        put ch
        putChild l
        putChild e
        putChild h
      where
        -- This child's contribution to the tag: its bit if the child is
        -- present, 0 if it is empty.
        presence bit t = case t of
            SEnd -> 0
            _    -> bit
        -- Empty children are implied by the tag, so nothing is written.
        putChild SEnd = return ()
        putChild t    = put t

    get = do
        tag <- getWord8
        case tag of
            8 -> return SEnd
            _ | tag <= 7 -> do
                    ch <- get
                    -- Read a child only when its presence bit is set.
                    l <- getChild (testBit tag 2)
                    e <- getChild (testBit tag 1)
                    h <- getChild (testBit tag 0)
                    return (SNode ch l e h)
              | otherwise ->
                    fail ("Data.Set.StringSet: invalid StringSet tag " ++ show tag)
      where
        getChild present = if present then get else return SEnd