-- | A set of non-negative 'Int' indices stored as a bit array inside a strict
-- 'B.ByteString'.
--
-- Index @i@ is a member when bit @i `mod` 8@ (bit 0 being the least
-- significant) of byte @i `div` 8@ is set.
module Data.BitSet
  ( BitSet
  , empty
  , fromByteString
  , fromRange
  , toList
  , isEmpty
  , intersect
  , union
  , subtract
  ) where

import Data.Bits
import qualified Data.ByteString as B
import Data.Word
import Prelude hiding (subtract)

-- | Bit set backed by a byte string. The type parameter @a@ is phantom: it
-- does not appear in the stored representation.
data BitSet a = BitSet B.ByteString

-- | Two sets are equal when their bytes agree after the shorter one has been
-- extended with zero bytes, so trailing zero bytes never affect equality.
instance Eq (BitSet a) where
  BitSet lhs == BitSet rhs = uncurry (==) (byteStringsPad lhs rhs)

-- | Rendered as the list of member indices produced by 'toList'.
instance Show a => Show (BitSet a) where
  show set = show (toList set)

-- | The set with no members (backed by an empty byte string).
empty :: Num a => BitSet a
empty = BitSet B.empty

-- | Wrap raw bytes as a set without copying or validating them.
fromByteString :: Num a => B.ByteString -> BitSet a
fromByteString bytes = BitSet bytes

-- | Build the set containing every index in the half-open range @[lo, hi)@.
--
-- Calls 'error' when @lo@ is negative or greater than @hi@. The check lives
-- inside the lazily evaluated payload, so the error is raised when the
-- underlying bytes are first demanded rather than when the set is constructed.
fromRange :: Int -> Int -> BitSet a
fromRange lo hi = BitSet bytes
  where
    bytes
      | lo < 0 = error "lower bound cannot be less than zero."
      | lo > hi = error "lower bound cannot be greater than upper bound."
      | lo == hi = B.empty
      | otherwise =
          -- Dispatch on whether each bound falls on a byte boundary.
          case (loBit == 0, hiBit == 0) of
            (True, True)
              | lo == 0 -> fullBytes
              | otherwise -> B.concat [zeroBytes, fullBytes]
            (True, False) -> B.concat [zeroBytes, fullBytes, lastByte]
            (False, True) -> B.concat [zeroBytes, firstByte, fullBytes]
            (False, False)
              -- Both bounds sit inside the same byte.
              | loByte == hiByte -> B.concat [zeroBytes, onlyByte]
              -- The partial bytes are adjacent: no full bytes in between.
              | nextByte == hiByte -> B.concat [zeroBytes, firstByte, lastByte]
              | nextByte < hiByte ->
                  B.concat [zeroBytes, firstByte, fullBytes, lastByte]
              | otherwise -> error "cannot happen"

    -- Byte index and bit offset of each bound.
    (loByte, loBit) = lo `divMod` 8
    (hiByte, hiBit) = hi `divMod` 8

    -- Index of the first byte that starts at or after @lo@.
    nextByte = (lo + 7) `div` 8

    -- Leading all-zero bytes that precede the range.
    zeroBytes = B.replicate loByte 0x00
    -- Bytes that lie completely inside the range.
    fullBytes = B.replicate (hiByte - nextByte) 0xff
    -- Partial byte at the start: bits from loBit up to bit 7.
    firstByte = B.singleton (setBits 0 loBit 8)
    -- Partial byte at the end: bits 0 up to (but excluding) hiBit.
    lastByte = B.singleton (setBits 0 0 hiBit)
    -- Single partial byte holding the entire range.
    onlyByte = B.singleton (setBits 0 loBit hiBit)

-- | The member indices in ascending order.
toList :: BitSet a -> [Int]
toList (BitSet bytes) = [ix | (True, ix) <- zip (byteStringBits bytes) [0 ..]]
-- | 'True' when no bit is set, i.e. every backing byte is zero (this includes
-- the case of no backing bytes at all).
isEmpty :: BitSet a -> Bool
isEmpty (BitSet bytes) = not (B.any (/= 0) bytes)

-- | Indices present in both sets.
intersect :: BitSet a -> BitSet a -> BitSet a
intersect = binaryOp (.&.)

-- | Indices present in either set.
union :: BitSet a -> BitSet a -> BitSet a
union = binaryOp (.|.)

-- | Indices present in the first set but not in the second.
subtract :: BitSet a -> BitSet a -> BitSet a
subtract = binaryOp (\keep remove -> keep .&. complement remove)

-- | Combine two sets byte by byte with the given function, after extending
-- the shorter one with zero bytes so that both have the same length.
binaryOp :: (Word8 -> Word8 -> Word8) -> BitSet a -> BitSet b -> BitSet c
binaryOp combine (BitSet lhs) (BitSet rhs) =
  BitSet (uncurry (byteStringPackZipWith combine) (byteStringsPad lhs rhs))

-- | All bits of the byte string: bytes in order, and within each byte the
-- bits from position 0 up to position 7.
byteStringBits :: B.ByteString -> [Bool]
byteStringBits bytes = concatMap word8Bits (B.unpack bytes)

-- | Zip two byte strings with a function and pack the results into a new
-- byte string.
byteStringPackZipWith
  :: (Word8 -> Word8 -> Word8)
  -> B.ByteString
  -> B.ByteString
  -> B.ByteString
byteStringPackZipWith combine xs ys = B.pack (B.zipWith combine xs ys)

-- | Extend the shorter of two byte strings with trailing zero bytes so that
-- both results have the same length.
byteStringsPad :: B.ByteString -> B.ByteString -> (B.ByteString, B.ByteString)
byteStringsPad b1 b2
  | deficit > 0 = (b1 `B.append` zeros deficit, b2)
  | otherwise = (b1, b2 `B.append` zeros (negate deficit))
  where
    -- Positive when the first string is the shorter one.
    deficit = B.length b2 - B.length b1
    zeros count = B.replicate count 0

-- | Set every bit position in the half-open range @[loBit, hiBit)@ on top of
-- the starting value @acc@. Returns @acc@ unchanged when the range is empty.
setBits :: Bits a => a -> Int -> Int -> a
setBits acc loBit hiBit = go acc loBit
  where
    go bits ix
      | ix >= hiBit = bits
      | otherwise = go (setBit bits ix) (ix + 1)

-- | The eight bits of a byte, from position 0 up to position 7.
word8Bits :: Word8 -> [Bool]
word8Bits w = [testBit w ix | ix <- [0 .. 7]]