module Data.Set.TernarySet where
import Control.Monad
import Data.Binary
import Data.List (foldl')


-- | One symbol stored in a tree node. @C x@ holds an element of an
-- inserted list (the field is strict); 'Null' is the terminator that marks
-- the position where an inserted list ends.
data Elem a = C !a
            | Null
             deriving (Show, Eq)
-- | A ternary search tree over lists of @a@. It is commonly used for
-- storing word lists such as dictionaries.
--
-- The fields of 'TNode' are, in order: the node's 'Elem', the subtree for
-- smaller elements, the subtree that continues a list whose current element
-- equals this node's 'Elem', and the subtree for greater elements. All
-- fields are strict. 'TEnd' is the empty tree.
data TernarySet a = TNode !(Elem a) !(TernarySet a) !(TernarySet a) !(TernarySet a)
                  | TEnd
               deriving (Show, Eq)

-- | 'Null' sorts strictly below every @C@ value and is equal only to
-- itself; two @C@ values are ordered by the 'Ord' instance of their
-- payloads.
instance Ord a => Ord (Elem a) where
    compare lhs rhs =
        case (lhs, rhs) of
            (Null, Null) -> EQ
            (Null, _)    -> LT
            (_, Null)    -> GT
            (C a, C b)   -> compare a b

-- | Builds a tree that contains exactly one list, without needing an
-- existing tree: a single chain of nodes linked through their middle
-- subtree and closed by a 'Null' node. Use it to create the initial tree
-- and 'insert' thereafter.
insert' :: Ord a => [a] -> TernarySet a
insert' = foldr chain terminator
  where
    -- Node marking the end of the stored list.
    terminator    = TNode Null TEnd TEnd TEnd
    -- Prepend one element; the rest of the list hangs off the middle branch.
    chain x below = TNode (C x) TEnd below TEnd

-- | Inserts one entry (a list of elements) into a tree. Inserting an entry
-- that is already present returns an equal tree.
insert :: Ord a => [a] -> TernarySet a -> TernarySet a
-- Nothing to merge with: lay out the remaining elements as a fresh chain
-- (for an empty list this is just the terminator node).
insert word TEnd = insert' word
-- The entry ends here and a terminator already exists: nothing to do.
insert [] node@(TNode Null _ _ _) = node
-- The entry ends here but this node holds a real element; the terminator
-- sorts below every element, so keep looking in the lower subtree.
insert [] (TNode key lo eq hi) = TNode key (insert [] lo) eq hi
-- General case: descend according to how the head compares with the node.
insert word@(c:rest) (TNode key lo eq hi) =
    case compare (C c) key of
        LT -> TNode key (insert word lo) eq hi
        EQ -> TNode key lo (insert rest eq) hi
        GT -> TNode key lo eq (insert word hi)


-- | Returns 'True' if the given list is an entry of the tree.
isElem :: Ord a => [a] -> TernarySet a -> Bool
-- An empty (sub)tree contains nothing.
isElem _ TEnd = False
-- The whole list has been consumed: it is an entry only if a terminator
-- node is found here or further down the lower subtree.
isElem [] (TNode Null _ _ _) = True
isElem [] (TNode _ lo _ _)   = isElem [] lo
-- Otherwise follow the branch selected by the head of the list.
isElem word@(c:rest) (TNode key lo eq hi) =
    case compare (C c) key of
        LT -> isElem word lo
        EQ -> isElem rest eq
        GT -> isElem word hi

-- | Returns the number of nodes that hold a real element, i.e. every
-- 'TNode' whose 'Elem' is not 'Null'.
treeSize :: TernarySet a -> Int
treeSize tree =
    case tree of
        TEnd               -> 0
        TNode ele lo eq hi -> weight ele + treeSize lo + treeSize eq + treeSize hi
  where
    -- Terminator nodes do not count towards the size.
    weight Null = 0
    weight _    = 1

-- | Counts how many entries there are in the tree, by counting the nodes
-- whose 'Elem' is 'Null'.
numEntries :: TernarySet a -> Int
numEntries tree =
    case tree of
        TEnd               -> 0
        TNode ele lo eq hi -> weight ele + numEntries lo + numEntries eq + numEntries hi
  where
    -- Only terminator nodes count.
    weight Null = 1
    weight _    = 0

-- | Creates a new tree from a list of entries, inserting them from left to
-- right starting with the empty tree.
--
-- A strict left fold is used so that each intermediate tree is built before
-- the next entry is inserted, instead of accumulating a chain of pending
-- 'insert' calls proportional to the length of the input.
fromList :: Ord a => [[a]] -> TernarySet a
fromList = foldl' (flip insert) TEnd

-- | Serialises an 'Elem' as a one-byte tag: @0@ for 'Null', or @1@
-- followed by the encoding of the wrapped value for @C@. Any other tag is
-- rejected as a decoding failure.
instance Binary a => Binary (Elem a) where
    put Null  = putWord8 0
    put (C x) = putWord8 1 >> put x
    get = do
        n <- getWord8
        case n of
            0 -> return Null
            1 -> liftM C get
            -- Corrupt input: report it through the decoder instead of
            -- crashing with a pattern-match failure.
            _ -> fail ("Data.Set.TernarySet.Elem: invalid tag " ++ show n)

-- | Compact serialisation. Every node starts with a one-byte tag in the
-- range 0-7 whose bits record which subtrees are non-empty (4 = lower,
-- 2 = equal, 1 = higher). The tag is followed by the node's 'Elem' and then
-- by the non-empty subtrees only, in lower, equal, higher order, so a
-- 'TEnd' below a node costs no bytes at all. Tag 8 encodes a tree that is
-- just 'TEnd'. Any other tag is rejected as a decoding failure.
instance Binary a => Binary (TernarySet a) where
    put TEnd = putWord8 8
    put (TNode ch l e h) = do
        putWord8 (flag 4 l + flag 2 e + flag 1 h)
        put ch
        putChild l
        putChild e
        putChild h
      where
        -- Contribution of one subtree to the tag: its bit if present, else 0.
        flag _ TEnd = 0
        flag w _    = w
        -- Empty subtrees are implied by the tag, so nothing is written.
        putChild TEnd  = return ()
        putChild child = put child
    get = do
        tag <- getWord8
        case tag of
            8 -> return TEnd
            _ | tag > 8 ->
                    -- Corrupt input: fail before consuming anything else.
                    fail ("Data.Set.TernarySet.TernarySet: invalid tag " ++ show tag)
              | otherwise -> do
                    ch <- get
                    l <- child (tag >= 4)
                    e <- child (tag `mod` 4 >= 2)
                    h <- child (odd tag)
                    return (TNode ch l e h)
      where
        -- Read a subtree only if the tag says one was written.
        child present = if present then get else return TEnd






    -- put TEnd = put (0 :: Word8)
    -- -- Quite common, so specialised
    -- put (TNode ch TEnd TEnd TEnd) = do
    --     putWord8 1
    --     put ch
    -- -- Also common, basically what insert' produces.
    -- put (TNode ch TEnd e TEnd) = do
    --     putWord8 2
    --     put ch
    --     put e
    -- -- General case
    -- put (TNode ch l e h) = do
    --     putWord8 3
    --     put ch
    --     put l
    --     put e
    --     put h
    -- get = do
    --     tag <- getWord8
    --     case tag of
    --         0 -> return TEnd
    --         1 -> do
    --             ch <- get
    --             return (TNode ch TEnd TEnd TEnd)
    --         2 -> do
    --             ch <- get
    --             e <- get
    --             return (TNode ch TEnd e TEnd)
    --         3 -> liftM4 TNode get get get get