module Data.BitSet
	( BitSet
	, empty
	, fromByteString
	, fromRange
	, toList
	, isEmpty
	, intersect
	, union
	, subtract
	)
	where

import Control.Exception
import Data.Bits
import qualified Data.ByteString as B
import Data.Word
import Prelude hiding (subtract)

data BitSet = BitSet B.ByteString

instance Eq BitSet
	where (BitSet b1) == (BitSet b2) = b1' == b2'
		where (b1', b2') = byteStringsPad b1 b2

instance Show BitSet where show = show . toList

empty :: BitSet
empty = BitSet B.empty

fromByteString :: B.ByteString -> BitSet
fromByteString = BitSet

fromRange :: Int -> Int -> BitSet
fromRange lo hi = BitSet generate where

	generate
		| lo <  0  = error "lower bound cannot be less than zero."
		| lo >  hi = error "lower bound cannot be greater than upper bound."
		| lo == hi = B.empty
		| lo    == 0 && hiBit == 0 = setBytes
		| loBit == 0 && hiBit == 0 = B.concat [clearBytes, setBytes]
		| loBit == 0 && hiBit /= 0 = B.concat [clearBytes, setBytes, fallByte]
		| loBit /= 0 && hiBit == 0 = B.concat [clearBytes, riseByte, setBytes]
		| loByteFloor   == hiByteFloor = B.concat [clearBytes, humpByte]
		| loByteCeiling == hiByteFloor = B.concat [clearBytes, riseByte, fallByte]
		| loByteCeiling <  hiByteFloor = B.concat [clearBytes, riseByte, setBytes, fallByte]

	(loBit, loByteFloor, loByteCeiling) = (lo `mod` 8, lo `div` 8, (lo + 7) `div` 8)
	(hiBit, hiByteFloor, hiByteCeiling) = (hi `mod` 8, hi `div` 8, (hi + 7) `div` 8)

	clearBytes = B.replicate (loByteFloor                ) 0x00
	setBytes   = B.replicate (hiByteFloor - loByteCeiling) 0xff

	riseByte = B.singleton $ setBits 0 loBit     8
	fallByte = B.singleton $ setBits 0     0 hiBit
	humpByte = B.singleton $ setBits 0 loBit hiBit

toList :: BitSet -> [Int]
toList (BitSet b) = map snd $ filter fst $ zip (byteStringBits b) [0 ..]

isEmpty :: BitSet -> Bool
isEmpty (BitSet b) = B.all (== 0) b

intersect :: BitSet -> BitSet -> BitSet
union     :: BitSet -> BitSet -> BitSet
subtract  :: BitSet -> BitSet -> BitSet

intersect = binaryOp (.&.)
union     = binaryOp (.|.)
subtract  = binaryOp (\x y -> x .&. complement y)

binaryOp f (BitSet b1) (BitSet b2) =
	BitSet $ byteStringPackZipWith f b1' b2'
	where (b1', b2') = byteStringsPad b1 b2

byteStringBits byteString = do
	word <- B.unpack byteString
	bit <- word8Bits word
	return bit

byteStringPackZipWith :: (Word8 -> Word8 -> Word8) -> B.ByteString -> B.ByteString -> B.ByteString
byteStringPackZipWith = ((B.pack .) .) . B.zipWith

byteStringsPad :: B.ByteString -> B.ByteString -> (B.ByteString, B.ByteString)
byteStringsPad b1 b2 =
	if length1 < length2
		then (B.append b1 (B.replicate (length2 - length1) 0), b2)
		else (b1, B.append b2 (B.replicate (length1 - length2) 0))
	where
		length1 = B.length b1
		length2 = B.length b2

setBits :: Bits a => a -> Int -> Int -> a
setBits value loBit hiBit = if loBit < hiBit
	then setBits (setBit value loBit) (loBit + 1) hiBit
	else value

word8Bits :: Word8 -> [Bool]
word8Bits w = map (testBit w) [0 .. 7]