module Data.BitSet
( BitSet
, empty
, fromByteString
, fromRange
, toList
, isEmpty
, intersect
, union
, subtract
)
where
import Control.Exception
import Data.Bits
import qualified Data.ByteString as B
import Data.Word
import Prelude hiding (subtract)
-- | A set of non-negative 'Int's stored as a little-endian bitmap: bit @j@
-- (0 = least significant) of byte @i@ represents the member @8 * i + j@.
-- Trailing zero bytes carry no members, so several encodings may describe
-- the same set (see the 'Eq' instance).
--
-- A @newtype@ rather than @data@: the wrapper adds no runtime cost and the
-- constructor is not exported, so callers cannot observe the difference.
newtype BitSet = BitSet B.ByteString
-- | Two sets are equal when their encodings agree after trailing zero bytes
-- (which encode no members) are discarded, so encodings of different
-- lengths can still compare equal.
instance Eq BitSet where
  (BitSet lhs) == (BitSet rhs) = significant lhs == significant rhs
    where
      -- Drop the run of 0x00 bytes at the end of the encoding.
      significant = fst . B.spanEnd (== 0)
-- | Renders the set as the list of its members in ascending order,
-- e.g. @fromRange 1 4@ is shown as @[1,2,3]@.
instance Show BitSet where
  show set = show (toList set)
-- | The set with no members, backed by a zero-length byte string.
empty :: BitSet
empty = BitSet mempty
-- | Interpret raw bytes as a set, using them as-is (no copying or
-- validation).
fromByteString
  :: B.ByteString -- ^ bit @j@ (0 = least significant) of byte @i@ is member @8 * i + j@
  -> BitSet
fromByteString = BitSet
-- | Build the set containing every integer in the half-open range
-- @[lo, hi)@. When @lo == hi@ the result is empty.
--
-- Calls 'error' when @lo < 0@ or when @lo > hi@.
--
-- The encoding is assembled from up to four pieces:
--
--   * zero bytes for every whole byte below @lo@;
--   * a partial \"rise\" byte when @lo@ is not byte-aligned;
--   * full @0xff@ bytes for the whole bytes inside the range;
--   * a partial \"fall\" byte when @hi@ is not byte-aligned.
--
-- If both ends fall inside the same byte, a single \"hump\" byte is used.
fromRange :: Int -> Int -> BitSet
fromRange lo hi = BitSet generate
  where
    generate
      | lo < 0 = error "lower bound cannot be less than zero."
      | lo > hi = error "lower bound cannot be greater than upper bound."
      | lo == hi = B.empty
      -- Both ends byte-aligned (covers lo == 0, where clearBytes is empty).
      | loBit == 0 && hiBit == 0 = B.concat [clearBytes, setBytes]
      | loBit == 0 = B.concat [clearBytes, setBytes, fallByte]
      | hiBit == 0 = B.concat [clearBytes, riseByte, setBytes]
      -- From here on neither end is byte-aligned.
      | loByteFloor == hiByteFloor = B.concat [clearBytes, humpByte]
      -- When the rise and fall bytes are adjacent, setBytes is empty.
      | otherwise = B.concat [clearBytes, riseByte, setBytes, fallByte]

    loBit = lo `mod` 8
    hiBit = hi `mod` 8
    loByteFloor = lo `div` 8
    loByteCeiling = (lo + 7) `div` 8
    hiByteFloor = hi `div` 8

    -- Whole bytes lying entirely below lo.
    clearBytes = B.replicate loByteFloor 0x00
    -- Whole bytes lying entirely inside [lo, hi).
    setBytes = B.replicate (hiByteFloor - loByteCeiling) 0xff
    -- Partial first byte: bits loBit .. 7.
    riseByte = B.singleton $ setBits 0 loBit 8
    -- Partial last byte: bits 0 .. hiBit - 1.
    fallByte = B.singleton $ setBits 0 0 hiBit
    -- Both ends in one byte: bits loBit .. hiBit - 1.
    humpByte = B.singleton $ setBits 0 loBit hiBit
-- | Members of the set in ascending order.
toList :: BitSet -> [Int]
toList (BitSet bytes) =
  [ member | (isSet, member) <- zip (byteStringBits bytes) [0 ..], isSet ]
-- | True when no bit is set. Zero bytes encode no members, so an encoding
-- made only of 0x00 bytes is still empty.
isEmpty :: BitSet -> Bool
isEmpty (BitSet bytes) = not (B.any (/= 0) bytes)
-- | Members present in both sets.
intersect :: BitSet -> BitSet -> BitSet
intersect = binaryOp (.&.)

-- | Members present in at least one of the sets.
union :: BitSet -> BitSet -> BitSet
union = binaryOp (.|.)

-- | Members of the first set that are not members of the second.
subtract :: BitSet -> BitSet -> BitSet
subtract = binaryOp andNot
  where
    andNot keep remove = keep .&. complement remove
-- | Lift a byte-wise operator to sets.
--
-- The shorter encoding is first zero-padded to the length of the longer one,
-- so @f@ sees 0x00 (no members) for bytes missing from the shorter set. The
-- bytes are then combined pairwise. The result is as long as the longer
-- input.
binaryOp :: (Word8 -> Word8 -> Word8) -> BitSet -> BitSet -> BitSet
binaryOp f (BitSet b1) (BitSet b2) =
  BitSet $ byteStringPackZipWith f b1' b2'
  where
    (b1', b2') = byteStringsPad b1 b2
-- | Flatten a byte string into its bits. Bytes are taken in order, and within
-- each byte the bits run from 0 (least significant) to 7. List index @k@
-- therefore corresponds to set member @k@.
byteStringBits :: B.ByteString -> [Bool]
byteStringBits = concatMap word8Bits . B.unpack
-- | Combine two byte strings byte by byte with @f@ and pack the results into
-- a new byte string. Like 'zipWith', the output is as long as the shorter
-- input; callers pad first when they need the longer length.
byteStringPackZipWith
  :: (Word8 -> Word8 -> Word8)
  -> B.ByteString
  -> B.ByteString
  -> B.ByteString
byteStringPackZipWith f xs ys = B.pack (B.zipWith f xs ys)
-- | Right-pad the shorter of two byte strings with 0x00 bytes so that both
-- have the length of the longer one. The longer (or equal-length) input is
-- returned unchanged, and the pair keeps the argument order.
byteStringsPad :: B.ByteString -> B.ByteString -> (B.ByteString, B.ByteString)
byteStringsPad b1 b2 = (padTo b1, padTo b2)
  where
    target = max (B.length b1) (B.length b2)
    -- Appends (target - length) zero bytes; a no-op for the longer input.
    padTo b = B.append b (B.replicate (target - B.length b) 0)
-- | Set every bit index in the half-open range @[loBit, hiBit)@ of @value@.
-- Bits outside the range are left as they were. An empty range
-- (@loBit >= hiBit@) returns @value@ unchanged.
setBits :: Bits a => a -> Int -> Int -> a
setBits value loBit hiBit = go value loBit
  where
    go acc index
      | index < hiBit = go (setBit acc index) (index + 1)
      | otherwise = acc
-- | The eight bits of a byte, least significant bit first.
word8Bits :: Word8 -> [Bool]
word8Bits byte = [testBit byte position | position <- [0 .. 7]]