{-# LANGUAGE FlexibleInstances, FlexibleContexts, MultiParamTypeClasses, DeriveDataTypeable, BangPatterns, PatternGuards, TypeFamilies #-}
module Data.Ring.Semi.BitSet
    ( module Data.Monoid.Reducer
    , BitSet
    , empty
    , singleton
    , null
    , full
    , complement
    , insert
    , delete
    , fromList
    , fromDistinctAscList
    , toInteger
    , (\\)
    , member
    , size
    ) where

import Prelude hiding ( null, exponent, toInteger )
import Data.Bits hiding ( complement )
import qualified Data.Bits as Bits
import Data.Data
import Data.Ring.Semi.Natural
import Data.Monoid.Reducer
import Data.Generator
import Data.Ring.Algebra

-- | A set of 'Enum' values stored as a bit field ('mantissa') whose bit 0 stands
-- for the element with @fromEnum@ value 'exponent'.
--
-- A negative 'mantissa' marks a complemented set: its members are the set bits
-- inside the window @[exponent .. _hwm]@ plus every value of the universe
-- outside that window. For such sets the three count fields hold
-- 'Bits.complement' of the corresponding counts of the un-complemented set
-- (see 'complement', which swaps and complements them).
data BitSet a = BS 
        { _countAtLeast  :: {-# UNPACK #-} !Int       -- ^ a conservative lower bound on the element count
        , _countAtMost   :: {-# UNPACK #-} !Int       -- ^ a conservative upper bound on the element count
        , _count         :: Int                       -- ^ the actual element count (lazy) used when the above two disagree
        , exponent       :: {-# UNPACK #-} !Int       -- ^ low water mark: @fromEnum@ of the element stored in bit 0 of the mantissa
        , _hwm           :: {-# UNPACK #-} !Int       -- ^ high water mark: @fromEnum@ of the highest position the mantissa window covers
        , mantissa       :: {-# UNPACK #-} !Integer   -- ^ the set of bits; a negative mantissa encodes a complemented set
        , _universe      :: (Int,Int)                 -- ^ invariant: mantissa < 0 => universe = (fromEnum minBound,fromEnum maxBound)
        -- NOTE(review): positive sets built by 'empty', 'singleton' and 'fromDistinctAscList'
        -- store 'undefined' in '_universe', so the derived 'Show' instance fails on them.
        } deriving (Data, Typeable,Show)

-- | Expose every field except the universe, for debugging.
debug :: BitSet a -> (Int,Int,Int,Int,Int,Integer)
debug s = case s of
    BS lo hi cnt lwm hwm bits _ -> (lo, hi, cnt, lwm, hwm, bits)

-- | Internal smart constructor. When the lower and upper count bounds agree the
-- count is known exactly, so that value is stored (already evaluated) instead
-- of the possibly expensive lazy count thunk.
bs :: Int -> Int -> Int -> Int -> Int -> Integer -> (Int,Int) -> BitSet a
bs !lo !hi cnt !lwm !hwm !bits univ =
    if lo == hi
        then BS lo lo lo lwm hwm bits univ
        else BS lo hi cnt lwm hwm bits univ
{-# INLINE bs #-}

-- instance (Enum a, Show a) => Show (BitSet a) where
--    show s = "fromDistinctAscList " ++ show (toList s) ++ 

-- | /O(d)/ where /d/ is the width of the mantissa window @[exponent .. hwm]@.
-- For a complemented set, the universe values outside the window are listed
-- too, so the cost grows with the size of the universe.
--
-- Elements are produced in ascending @fromEnum@ order.
toList :: Enum a => BitSet a -> [a]
toList (BS _ _ _ l h m u)
    | m < 0     = map toEnum (below ++ window ++ above)
    | otherwise = map toEnum window
    where
        -- only demanded for complemented sets; positive sets may leave it undefined
        (ul, uh) = u
        -- members inside the window: positions whose bit (relative to l) is set
        window = [ n | n <- [l .. h], testBit m (n - l) ]
        -- complemented sets also contain every universe value outside the window;
        -- takeWhile and the explicit guard avoid pred/succ overflow at the
        -- universe bounds and never repeat the window's edge values
        below = takeWhile (< l) [ul .. uh]
        above = if h < uh then [h + 1 .. uh] else []
{-# INLINE toList #-}

-- | The empty bit set: a window consisting of position 0 with no bit set.
-- No universe is recorded (the field is left undefined).
empty :: BitSet a
empty = BS
    { _countAtLeast = 0
    , _countAtMost  = 0
    , _count        = 0
    , exponent      = 0
    , _hwm          = 0
    , mantissa      = 0
    , _universe     = undefined
    }
{-# INLINE empty #-}

-- | A set holding exactly one element: the window is the single position
-- @fromEnum x@ with bit 0 set. No universe is recorded.
singleton :: Enum a => a -> BitSet a
singleton x =
    let pos = fromEnum x
    in  BS 1 1 1 pos pos 1 undefined
{-# INLINE singleton #-}

-- | Is the bit set empty? This is asymptotically faster than checking
-- @size == 0@ in some cases. The stored bounds are consulted first, and the
-- lazy exact count is forced only when they are inconclusive.
null :: BitSet a -> Bool
null (BS lo hi cnt _ _ _ _) = lo <= 0 && (hi == 0 || cnt == 0)
{-# INLINE null #-}

-- | Every value of a bounded element type: the complement of 'empty'.
full :: (Enum a, Bounded a) => BitSet a
full = complement empty

-- | @(fromEnum minBound, fromEnum maxBound)@ for the element type of the given
-- set. The set only fixes the type; it is never evaluated.
universeOf :: (Bounded a, Enum a) => BitSet a -> (Int,Int)
universeOf s = (lowest, highest)
    where
        lowest  = fromEnum (minBound `asArgTypeOf` s)
        highest = fromEnum (maxBound `asArgTypeOf` s)

-- | Complement a set with respect to its element type. The universe is
-- recomputed from the type rather than trusted from the input, so the
-- result is always tagged with a valid universe. The result may be a
-- complemented (negative-mantissa) set.
complement :: (Enum a, Bounded a) => BitSet a -> BitSet a 
complement s = (recomplement s) { _universe = universeOf s }

-- | Swap the count bounds, bitwise-complement the bounds, the count and the
-- mantissa, and keep the stored universe as-is.
-- Proof obligation on callers: the retained universe must be meaningful for the
-- result. Either the input is already complemented (so its universe is valid),
-- or the result is itself a complement of a complement.
recomplement :: BitSet a -> BitSet a 
recomplement (BS lo hi cnt lwm hwm bits univ) =
    BS (flipBits hi) (flipBits lo) (flipBits cnt) lwm hwm (flipBits bits) univ
    where
        flipBits :: Bits b => b -> b
        flipBits = Bits.complement

-- | /O(d * n)/ Make a @BitSet@ from a list of items by inserting them one at a
-- time into 'empty', starting from the end of the list.
fromList :: Enum a => [a] -> BitSet a
fromList = foldr (\item acc -> insert item acc) empty
{-# INLINE fromList #-}

-- | /O(d)/ Build a 'BitSet' from a list whose elements are strictly ascending
-- by 'fromEnum'. The precondition is not checked; duplicates would make the
-- stored counts wrong.
fromDistinctAscList :: Enum a => [a] -> BitSet a 
fromDistinctAscList [] = empty
-- the first element is both the low and the (initial) high water mark
fromDistinctAscList (c:cs) = fromDistinctAscList' cs 1 l 1 
    where
        -- low water mark: the first, hence least, element
        l = fromEnum c
        -- n: elements seen so far; h: fromEnum of the latest (greatest) element;
        -- m: membership bits relative to l
        fromDistinctAscList' :: Enum a => [a] -> Int -> Int -> Integer -> BitSet a
        fromDistinctAscList' [] !n !h !m  = BS n n n l h m undefined
        fromDistinctAscList' (c':cs') !n _ !m = fromDistinctAscList' cs' (n+1) h' (setBit m (h' - l))
            where
                h' = fromEnum c'
{-# INLINE fromDistinctAscList #-}

-- | /O(d)/ Insert an item into the bit set.
--
-- For a positive set, the window @[exponent .. hwm]@ grows as needed to cover the
-- new element; both water marks are absolute @fromEnum@ values. For a
-- complemented set, every value outside the window is already a member, so
-- only a clear bit inside the window can change.
insert :: Enum a => a -> BitSet a -> BitSet a
insert x r@(BS a b c l h m u) 
    | m < 0       = insertComplemented
    | e < l       = grown e h (shiftL m (l - e) .|. 1)   -- e becomes the new low water mark (bit 0)
    | e > h       = grown l e (setBit m p)               -- e becomes the new high water mark
    | testBit m p = r
    | otherwise   = grown l h (setBit m p)
    where 
        e = fromEnum x
        p = e - l 
        -- exactly one new member: every count moves up by one
        grown l' h' m' = bs (a+1) (b+1) (c+1) l' h' m' u
        -- Complemented sets store 'Bits.complement' of the non-member counts, so
        -- gaining a member adds one to each stored value. The upper bound is
        -- clamped at -1 (zero non-members): the stored bound may be as loose as
        -- -1, and letting it reach 0 would make the set look empty.
        insertComplemented
            | e < l || e > h || testBit m p = r
            | otherwise = bs (a+1) (min (-1) (b+1)) (c+1) l h (setBit m p) u
{-# INLINE insert #-}

-- | /O(d)/ Delete an item from the bit set.
--
-- For a positive set, only a set bit inside the window can change. For a
-- complemented set, every value outside the window is a member, so deleting one
-- of them widens the window to reach it. Both water marks are absolute
-- @fromEnum@ values.
delete :: Enum a => a -> BitSet a -> BitSet a
delete x r@(BS a b c l h m u) 
    -- widen downwards: new positions 1 .. l-e-1 stay members (set), bit 0 (e) is clear
    | m < 0, e < l   = shrunk e h (shiftL m (l - e) .|. (bit (l - e) - 2))
    -- widen upwards: bits above the window of a negative mantissa are set (members)
    | m < 0, e > h   = shrunk l e (clearBit m p)
    | e < l || e > h = r
    | testBit m p    = shrunk l h (clearBit m p)
    | otherwise      = r
    where 
        e = fromEnum x
        p = e - l
        -- Exactly one member is removed. Positive sets store member counts, so
        -- each drops by one, with the lower bound clamped at 0 (it may already
        -- be 0, and -1 would read as "full"). Complemented sets store
        -- 'Bits.complement' of the non-member counts, which grow by one, so each
        -- stored value also drops by one.
        lower' | m < 0     = a - 1
               | otherwise = max 0 (a - 1)
        shrunk l' h' m' = bs lower' (b-1) (c-1) l' h' m' u
{-# INLINE delete #-}

-- | /O(testBit on Integer)/ Ask whether the item is in the bit set.
--
-- Values outside the window @[exponent .. hwm]@ are members exactly when the set
-- is complemented (negative mantissa); inside it the corresponding bit decides.
member :: Enum a => a -> BitSet a -> Bool
member x (BS _ _ _ l h m _) 
    | e < l || e > h = m < 0
    | otherwise      = testBit m (e - l)
    where 
        e = fromEnum x
{-# INLINE member #-}

-- | /O(1)/ or /O(d)/ The number of elements in the bit set.
--
-- When the stored bounds agree they give the count directly; otherwise the lazy
-- exact count is forced. A complemented set stores 'Bits.complement' of the count
-- of its complement, which is subtracted from the size of the universe.
size :: BitSet a -> Int
-- the universe is matched lazily: positive sets may leave it undefined
size (BS a b c _ _ m ~(ul,uh)) 
    | m >= 0    = stored
    | otherwise = (uh - ul + 1) - Bits.complement stored
    where
        stored | a == b    = a
               | otherwise = c

-- | /O(d)/ Convert to an 'Integer' by shifting the mantissa by the exponent, so bit
-- @i@ of the result corresponds to the element with @fromEnum@ value @i@. Bits at
-- negative positions are shifted out, i.e. negative elements are discarded.
toInteger :: BitSet a -> Integer
toInteger (BS _ _ _ lwm _ bits _) = shift bits lwm

-- | Set union. Positive operands are merged bitwise; complemented operands are
-- reduced to positive operations via De Morgan's laws.
union :: BitSet a -> BitSet a -> BitSet a 
union x@(BS a b c l h m u) y@(BS a' b' c' l' h' m' u')
    | l' < l    = union y x                                                         -- ensure left side has lower exponent
    | b == 0    = y                                                                 -- fast empty union
    | b' == 0   = x                                                                 -- fast empty union
    | a == -1   = BS (-1) (-1) (-1) 0 0 (-1) u                                      -- fast full union, recomplement obligation met by negative size
    | a' == -1  = BS (-1) (-1) (-1) 0 0 (-1) u'                                     -- fast full union, recomplement obligation met by negative size
    -- not X `union` not Y = not (X `intersection` Y). The intersection can come back
    -- as 'empty', whose universe is undefined, so the result is re-tagged with u.
    | m < 0, m' < 0 = (recomplement (intersection (recomplement x) (recomplement y))) { _universe = u }
    | m' < 0    = recomplement (pseudoDiff (recomplement y) x u')                   -- x `union` not Y = not (Y \\ x)
    | m < 0     = recomplement (pseudoDiff (recomplement x) y u)                    -- not X `union` y = not (X \\ y)
    | h < l'    = bs (a + a') (b + b') (c + c') l h' m'' u                          -- disjoint positive ranges
    | otherwise = bs (a `max` a') (b + b') (recount m'') l (h `max` h') m'' u       -- overlapped positives
    where 
        -- y's bits re-based onto x's (lower) exponent
        m'' = m .|. shiftL m' (l' - l)

-- | Set intersection. Positive operands are intersected bitwise over the
-- overlap of their windows; complemented operands are reduced to positive
-- operations via De Morgan's laws and 'pseudoDiff'.
intersection :: BitSet a -> BitSet a -> BitSet a 
intersection x@(BS a b _ l h m u) y@(BS a' b' _ l' h' m' u')
    | l' < l = intersection y x                                  -- ensure left side has lower exponent
    | b == 0 = empty                                             -- x is empty
    | b' == 0 = empty                                            -- y is empty
    | a == -1 = y                                                -- x is full
    | a' == -1 = x                                               -- y is full
    | m < 0, m' < 0 = recomplement (union (recomplement x) (recomplement y))  -- not X `intersection` not Y = not (X `union` Y)
    | m' < 0 = pseudoDiff x (recomplement y) u'                  -- x `intersection` not Y = x \\ Y
    | m < 0 = pseudoDiff y (recomplement x) u                    -- not X `intersection` y = y \\ X
    | h < l' = empty                                             -- disjoint positive windows
    | otherwise = bs 0 (b `min` b') (recount m'') l'' (h `min` h') m'' u
    where
        -- The result is based at the higher exponent l''. Each mantissa is shifted
        -- right by its distance below l'' so that bit 0 of both stands for l''.
        l'' = max l l'
        m'' = shift m (l - l'') .&. shift m' (l' - l'')

-- | Set difference @x \\ y@ of two positive (uncomplemented) sets. When the
-- windows overlap, the result is tagged with the supplied universe; when they
-- are disjoint, x is returned unchanged.
-- Callers guarantee m >= 0, m' >= 0, a /= -1, a' /= -1, b /= 0, b' /= 0; u'' is
-- the universe of discourse.
pseudoDiff :: BitSet a -> BitSet a -> (Int,Int) -> BitSet a 
pseudoDiff x@(BS a b _ l h m _) (BS _ b' _ l' h' m' _) u''
    | h < l'    = x
    | h' < l    = x
    -- x \\ y has at most as many elements as x (upper bound b) and at least a - b'
    | otherwise = bs (max (a - b') 0) b (recount m'') l h m'' u''
    where
        -- Align y's mantissa to x's exponent *before* complementing. Complementing
        -- first and then shifting left would zero the low bits and drop the
        -- members of x that lie below l'.
        m'' = m .&. Bits.complement (shift m' (l' - l))

-- | Set difference: the members of the first set that are not in the second,
-- computed as the intersection with the second set's complement.
(\\) :: (Enum a, Bounded a) => BitSet a -> BitSet a -> BitSet a 
(\\) x y = intersection x (complement y)

-- | Comparison keys shared by 'Eq' and 'Ord'. Both mantissas are re-based
-- (shifted left) onto the smaller exponent, so equal sets yield equal keys
-- whatever their low water marks. A complemented set is keyed by its
-- complemented (non-negative) mantissa and tagged 'True', so the zero bits a
-- left shift introduces cannot be mistaken for non-members.
alignedKeys :: BitSet a -> BitSet a -> ((Bool, Integer), (Bool, Integer))
alignedKeys (BS _ _ _ l _ m _) (BS _ _ _ l' _ m' _) = (key l m, key l' m')
    where
        base = min l l'
        key e bits
            | bits < 0  = (True, shiftL (Bits.complement bits) (e - base))
            | otherwise = (False, shiftL bits (e - base))

-- TODO: a positive set that happens to cover the whole universe still compares
-- unequal to the equivalent complemented set.
instance Eq (BitSet a) where
    x == y = uncurry (==) (alignedKeys x y)

instance Ord (BitSet a) where
    compare x y = uncurry compare (alignedKeys x y)

-- | 'minBound' is the empty set. 'maxBound' is a positive set with every bit of
-- the element type's range set, tagged with that range as its universe.
instance (Enum a, Bounded a) => Bounded (BitSet a) where
    minBound = empty
    maxBound = everything
      where
        lo = fromEnum (minBound `asArgTypeOf` everything)
        hi = fromEnum (maxBound `asArgTypeOf` everything)
        width = hi - lo + 1
        everything = BS width width width lo hi (bit width - 1) (lo, hi)

-- | Return the first argument unchanged; the second argument is never evaluated
-- and only pins the first argument's type to the element type of @f@.
asArgTypeOf :: a -> f a -> a
asArgTypeOf x _ = x
{-# INLINE asArgTypeOf #-}

-- | Number of set bits in a non-negative 'Integer', counted by repeatedly
-- shifting right until the value reaches zero. It does not terminate on
-- negative input.
recount :: Integer -> Int
recount = length . filter (`testBit` 0) . takeWhile (/= 0) . iterate (`shiftR` 1)

-- Values produced by 'toEnum' carry only loose count bounds (lower bound 0 or 1,
-- upper bound the index itself), so operations on them are comparatively slow.
instance (Enum a, Bounded a) => Enum (BitSet a) where
    -- read the mantissa as a bit field based at fromEnum minBound of the element type
    fromEnum s@(BS _ _ _ lwm _ bits _) = fromInteger (bits `shiftL` (lwm - base))
        where
            base = fromEnum (minBound `asArgTypeOf` s)
    -- interpret the bits of the index as membership flags, bit 0 being minBound
    toEnum i = s
        where
            bits = fromIntegral i
            -- a non-zero bit field has at least one member, which allows a fast null check
            lower = if bits /= 0 then 1 else 0
            -- n <= 2^n, so i itself serves as a valid upper bound
            s = BS lower i (recount bits)
                   (fromEnum (minBound `asArgTypeOf` s))
                   (fromEnum (maxBound `asArgTypeOf` s))
                   bits undefined
        
-- | Sets form a monoid under 'union', with 'empty' as the identity.
instance Enum a => Monoid (BitSet a) where
    mempty = empty
    mappend x y = union x y

-- | Reducing an element into a set inserts it; the side it arrives from is irrelevant.
instance Enum a => Reducer a (BitSet a) where
    unit x = singleton x
    snoc s x = insert x s
    cons x s = insert x s

-- | Multiplication is 'intersection', whose unit is the full set.
instance (Bounded a, Enum a) => Multiplicative (BitSet a) where
    one = full
    times x y = intersection x y

-- Near-ring and semiring structure: no methods are overridden, the class
-- defaults are used.
instance (Bounded a, Enum a) => LeftSemiNearRing (BitSet a)
instance (Bounded a, Enum a) => RightSemiNearRing (BitSet a)
instance (Bounded a, Enum a) => SemiRing (BitSet a)

-- Scaling by a natural number. Union is idempotent, so any positive number of
-- copies of a set is the set itself, and zero copies is 'empty'.
instance Enum a => LeftModule Natural (BitSet a) where
    n *. s
        | n == 0    = empty
        | otherwise = s
instance Enum a => RightModule Natural (BitSet a) where
    s .* n
        | n == 0    = empty
        | otherwise = s
instance Enum a => Module Natural (BitSet a)

-- A set acts on sets (from either side) by multiplication, i.e. intersection.
instance (Bounded a, Enum a) => LeftModule (BitSet a) (BitSet a) where
    x *. y = times x y
instance (Bounded a, Enum a) => RightModule (BitSet a) (BitSet a) where
    x .* y = times x y
instance (Bounded a, Enum a) => Module (BitSet a) (BitSet a)

instance (Bounded a, Enum a) => Algebra Natural (BitSet a)
    
-- | A set generates its members, in the ascending order produced by 'toList'.
instance Enum a => Generator (BitSet a) where
    type Elem (BitSet a) = a
    mapReduce f s = mapReduce f (toList s)