module Data.Set.StringSet (
            StringSet,
            insert,
            singleton,
            member,
            size,
            fromList
            ) where
import Control.Monad
import Data.List (foldl')

import Data.Binary

-- | A set of strings stored as a ternary search tree, a structure often used
-- for word lists such as spell-checking dictionaries.
--
-- Every node holds one character and three subtrees. The end of a stored
-- word is marked by a node holding the NUL character.
data StringSet
    = SNode
        {-# UNPACK #-} !Char -- ^ character held by this node
        !StringSet           -- ^ low: strings whose character here is smaller
        !StringSet           -- ^ equal: continuations after matching this character
        !StringSet           -- ^ high: strings whose character here is larger
    | SEnd                   -- ^ the empty tree
    deriving (Show, Eq)

-- | Insert a string into the set. Inserting a string that is already a
-- member rebuilds the path to it and yields a structurally equal tree.
--
-- Words are stored one character per node and terminated by a node holding
-- the NUL character. Strings that themselves contain NUL characters are
-- therefore not stored faithfully.
insert :: String -> StringSet -> StringSet
-- Empty (sub)tree: the remaining suffix becomes a fresh chain of nodes
-- (for an empty suffix this is just the NUL end-of-word node).
insert xs SEnd = singleton xs
-- End of the word inside a non-empty (sub)tree: look for the NUL marker.
-- '\0' is minBound for Char, so it can never sort above 'ele'; either it
-- is this node or it belongs in the low subtree.
insert [] t@(SNode ele l e h)
    | ele == '\0' = t
    | otherwise   = SNode ele (insert [] l) e h
-- General case: branch on how the next character compares with this node.
insert xss@(x:xs) (SNode ele l e h) =
    case compare x ele of
        LT -> SNode ele (insert xss l) e h
        EQ -> SNode ele l (insert xs e) h
        GT -> SNode ele l e (insert xss h)

-- | The set containing exactly one string: a vertical chain of nodes, one
-- per character, linked through their equal subtrees and closed by a NUL
-- end-of-word node.
singleton :: String -> StringSet
singleton = foldr link endOfWord
  where
    link c rest = SNode c SEnd rest SEnd
    endOfWord   = SNode '\0' SEnd SEnd SEnd

-- | Test whether a string is stored in the set.
--
-- The query is consumed one character at a time. Once it is used up, the
-- word is present only if a NUL end-of-word node is reachable by following
-- low subtrees from the current node.
member :: String -> StringSet -> Bool
member word tree = go word tree
  where
    go _ SEnd = False
    go [] (SNode k lo _ _) = k == '\0' || go [] lo
    go s@(c:cs) (SNode k lo eq hi)
        | c < k     = go s lo
        | c > k     = go s hi
        | otherwise = go cs eq

-- | Count the nodes holding a character other than NUL, i.e. every node
-- except the end-of-word markers.
--
-- NOTE(review): not in the export list and not called anywhere in the
-- visible code.
treeSize :: StringSet -> Int
treeSize SEnd = 0
treeSize (SNode c lo eq hi) = own + sum (map treeSize [lo, eq, hi])
  where
    own = if c == '\0' then 0 else 1

-- | Count the NUL end-of-word nodes in the tree. For a set built with
-- 'insert' from strings that contain no NUL characters, this is the number
-- of stored strings.
size :: StringSet -> Int
size SEnd = 0
size (SNode c lo eq hi)
    | c == '\0' = 1 + below
    | otherwise = below
  where
    below = sum (map size [lo, eq, hi])

-- | Build a set from a list of strings, inserting them left to right, so the
-- resulting tree has the same shape as repeated 'insert' calls.
--
-- Uses the strict left fold @foldl'@: each intermediate tree is forced as
-- the list is consumed. Lazy @foldl@ would first build one unevaluated
-- 'insert' thunk per list element.
fromList :: [String] -> StringSet
fromList = foldl' (flip insert) SEnd

-- | The set that contains no strings.
empty :: StringSet
empty = SEnd

-- | 'True' exactly when the set is the empty tree 'SEnd'.
--
-- NOTE(review): this name coincides with Prelude's @null@, so unqualified
-- uses inside this module would be ambiguous. Neither 'empty' nor this
-- function is in the module's export list.
null :: StringSet -> Bool
null t = case t of
    SEnd    -> True
    SNode{} -> False

-- | Compact binary encoding.
--
-- A bare 'SEnd' is written as the tag byte 8. A node is written as a tag
-- byte in the range 0-7, then its character, then only those children that
-- are not 'SEnd', in the order low, equal, high. The tag records which
-- children are present: 4 for the low child, 2 for the equal child and 1
-- for the high child, added together. Absent children are not written.
--
-- Decoding any other tag byte fails inside 'Get' (so e.g. @decodeOrFail@
-- reports it) instead of raising a pattern-match error.
instance Binary StringSet where
    put SEnd = putWord8 8
    put (SNode ch l e h) = do
        putWord8 (presence 4 l + presence 2 e + presence 1 h)
        put ch
        mapM_ putChild [l, e, h]
      where
        -- Contribution of one child to the tag: its bit if present, else 0.
        presence _   SEnd = 0
        presence bit _    = bit
        -- 'SEnd' children are implied by the tag, so nothing is written.
        putChild SEnd  = return ()
        putChild child = put child

    get = do
        tag <- getWord8
        case tag of
            8 -> return SEnd
            _ | tag < 8 -> do
                    ch <- get
                    -- Read back exactly the children whose bit is set, in
                    -- the same low, equal, high order used by 'put'.
                    let child bit
                            | odd (tag `div` bit) = get
                            | otherwise           = return SEnd
                    l <- child 4
                    e <- child 2
                    h <- child 1
                    return (SNode ch l e h)
              | otherwise ->
                    -- Corrupt or foreign input: report a decoding failure
                    -- rather than crashing with a pattern-match error.
                    fail ("Data.Set.StringSet.get: invalid node tag " ++ show tag)