{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE BlockArguments #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DeriveDataTypeable #-}
{-# LANGUAGE ExplicitForAll #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE UnboxedTuples #-}
{-# LANGUAGE UnicodeSyntax #-}
module Data.Enum.Set.Base
(
EnumSet
, empty
, singleton
, fromFoldable
, insert
, delete
, member
, notMember
, null
, size
, isSubsetOf
, union
, difference
, (\\)
, symmetricDifference
, intersection
, filter
, partition
, map
, map'
, foldl, foldl', foldr, foldr'
, foldl1, foldl1', foldr1, foldr1'
, foldMap
, traverse
, any
, all
, minimum
, maximum
, deleteMin
, deleteMax
, minView
, maxView
, toList
, fromRaw
) where
import qualified GHC.Exts
import qualified Data.Foldable as F
import Prelude hiding (all, any, filter, foldl, foldl1, foldMap, foldr, foldr1, map, maximum, minimum, null, traverse)
import Control.Applicative (liftA2)
import Control.DeepSeq (NFData)
import Control.Monad
import Data.Aeson (ToJSON(..))
import Data.Bits
import Data.Data (Data)
import Data.Monoid (Monoid(..))
import Data.Vector.Unboxed (Vector, MVector, Unbox)
import Foreign.Storable (Storable)
import GHC.Exts (IsList(Item), build)
import Text.Read
import qualified Data.Vector.Generic as G
import qualified Data.Vector.Generic.Mutable as M
import qualified Data.Vector.Primitive as P
import qualified Data.Containers
import Data.Containers (SetContainer, IsSet)
import qualified Data.MonoTraversable
import Data.MonoTraversable (Element, GrowingAppend, MonoFoldable, MonoFunctor, MonoPointed, MonoTraversable)
-- | A set of values of an 'Enum' type @a@, represented as a bit field in the
-- word type @word@: an element @x@ is present exactly when bit
-- @'fromEnum' x@ of the word is set (see 'insert' and 'member').
--
-- @a@ is a phantom parameter; at runtime only the @word@ is stored, which is
-- why the listed instances can be derived through the newtype
-- (GeneralizedNewtypeDeriving).
newtype EnumSet word a = EnumSet word
  deriving (Eq, Ord, Data, Storable, NFData, P.Prim, Unbox)
-- | Unboxed-vector representation of 'EnumSet'. The mutable ('MVector') and
-- immutable ('Vector') data-family instances each wrap a primitive vector
-- ('P.MVector' / 'P.Vector') of the same element type. The vector class
-- instances below unwrap these constructors and delegate to the primitive
-- implementation.
newtype instance MVector s (EnumSet word a) = MV_EnumSet (P.MVector s (EnumSet word a))
newtype instance Vector (EnumSet word a) = V_EnumSet (P.Vector (EnumSet word a))
-- | Mutable unboxed-vector operations for 'EnumSet'. Each method peels off
-- 'MV_EnumSet', runs the matching primitive-vector operation, and rewraps any
-- vector it produces.
instance P.Prim word => M.MVector MVector (EnumSet word a) where
  {-# INLINE basicLength #-}
  basicLength (MV_EnumSet pv) = M.basicLength pv

  {-# INLINE basicUnsafeSlice #-}
  basicUnsafeSlice off len (MV_EnumSet pv) =
    MV_EnumSet (M.basicUnsafeSlice off len pv)

  {-# INLINE basicOverlaps #-}
  basicOverlaps (MV_EnumSet xs) (MV_EnumSet ys) = M.basicOverlaps xs ys

  {-# INLINE basicUnsafeNew #-}
  basicUnsafeNew len = liftM MV_EnumSet (M.basicUnsafeNew len)

  {-# INLINE basicInitialize #-}
  basicInitialize (MV_EnumSet pv) = M.basicInitialize pv

  {-# INLINE basicUnsafeReplicate #-}
  basicUnsafeReplicate len s = liftM MV_EnumSet (M.basicUnsafeReplicate len s)

  {-# INLINE basicUnsafeRead #-}
  basicUnsafeRead (MV_EnumSet pv) ix = M.basicUnsafeRead pv ix

  {-# INLINE basicUnsafeWrite #-}
  basicUnsafeWrite (MV_EnumSet pv) ix s = M.basicUnsafeWrite pv ix s

  {-# INLINE basicClear #-}
  basicClear (MV_EnumSet pv) = M.basicClear pv

  {-# INLINE basicSet #-}
  basicSet (MV_EnumSet pv) s = M.basicSet pv s

  {-# INLINE basicUnsafeCopy #-}
  basicUnsafeCopy (MV_EnumSet dst) (MV_EnumSet src) = M.basicUnsafeCopy dst src

  {-# INLINE basicUnsafeMove #-}
  basicUnsafeMove (MV_EnumSet dst) (MV_EnumSet src) = M.basicUnsafeMove dst src

  {-# INLINE basicUnsafeGrow #-}
  basicUnsafeGrow (MV_EnumSet pv) extra =
    liftM MV_EnumSet (M.basicUnsafeGrow pv extra)
-- | Immutable unboxed-vector operations for 'EnumSet'. Methods delegate to
-- the primitive vector inside 'V_EnumSet'. Freeze, thaw and copy cross over to
-- the mutable side through 'MV_EnumSet'.
instance P.Prim word => G.Vector Vector (EnumSet word a) where
  {-# INLINE basicUnsafeFreeze #-}
  basicUnsafeFreeze (MV_EnumSet mpv) = liftM V_EnumSet (G.basicUnsafeFreeze mpv)

  {-# INLINE basicUnsafeThaw #-}
  basicUnsafeThaw (V_EnumSet pv) = liftM MV_EnumSet (G.basicUnsafeThaw pv)

  {-# INLINE basicLength #-}
  basicLength (V_EnumSet pv) = G.basicLength pv

  {-# INLINE basicUnsafeSlice #-}
  basicUnsafeSlice off len (V_EnumSet pv) =
    V_EnumSet (G.basicUnsafeSlice off len pv)

  {-# INLINE basicUnsafeIndexM #-}
  basicUnsafeIndexM (V_EnumSet pv) ix = G.basicUnsafeIndexM pv ix

  {-# INLINE basicUnsafeCopy #-}
  basicUnsafeCopy (MV_EnumSet dst) (V_EnumSet src) = G.basicUnsafeCopy dst src

  -- Elements are plain words, so forcing one to WHNF is enough.
  {-# INLINE elemseq #-}
  elemseq _ x y = x `seq` y
-- | Combining two sets yields their union.
instance Bits w => Semigroup (EnumSet w a) where
  s <> t = union s t
  {-# INLINE (<>) #-}
-- | The identity for union is the set with every bit clear.
instance Bits w => Monoid (EnumSet w a) where
  mempty = EnumSet zeroBits
  {-# INLINE mempty #-}
-- | Lifting a single element gives the one-element set.
instance (Bits w, Enum a) => MonoPointed (EnumSet w a) where
  opoint x = singleton x
  {-# INLINE opoint #-}
-- | List conversions, e.g. for list literals under OverloadedLists.
-- Building goes through 'fromFoldable'; reading back yields the elements in
-- ascending order of 'fromEnum'.
--
-- The right-hand sides name this module's top-level functions. The class
-- methods are only in scope qualified (@GHC.Exts.toList@), so 'toList' below
-- does not recurse into itself.
instance (FiniteBits w, Num w, Enum a) => IsList (EnumSet w a) where
  type Item (EnumSet w a) = a

  {-# INLINE fromList #-}
  fromList xs = fromFoldable xs

  {-# INLINE toList #-}
  toList s = toList s
-- | A set is serialised as the JSON encoding of its element list, in
-- ascending order of 'fromEnum'.
instance (FiniteBits w, Num w, Enum a, ToJSON a) => ToJSON (EnumSet w a) where
  toJSON s = toJSON (toList s)
  {-# INLINE toJSON #-}
  toEncoding s = toEncoding (toList s)
  {-# INLINE toEncoding #-}
-- | The element type seen by the mono-traversable classes is the phantom
-- element type @a@, not the underlying word.
type instance Element (EnumSet w a) = a
-- | Mapping goes through 'map'. Elements whose images share a 'fromEnum'
-- collapse into one.
instance (FiniteBits w, Num w, Enum a) => MonoFunctor (EnumSet w a) where
  omap f s = map f s
  {-# INLINE omap #-}
-- | Folding an 'EnumSet' visits elements in ascending order of 'fromEnum'.
-- The first element is therefore 'minimum' and the last is 'maximum'. The
-- partial methods ('ofoldr1Ex', 'ofoldl1Ex'', 'headEx', 'lastEx') call
-- 'error' on an empty set.
--
-- Unqualified names on the right-hand sides are this module's top-level
-- functions. The class methods are only imported qualified, so none of these
-- bindings is recursive.
instance (FiniteBits w, Num w, Enum a) => MonoFoldable (EnumSet w a) where
  {-# INLINE ofoldMap #-}
  ofoldMap f s = foldMap f s
  {-# INLINE ofoldr #-}
  ofoldr f z s = foldr f z s
  {-# INLINE ofoldl' #-}
  ofoldl' f z s = foldl' f z s
  {-# INLINE ofoldr1Ex #-}
  ofoldr1Ex f s = foldr1 f s
  {-# INLINE ofoldl1Ex' #-}
  ofoldl1Ex' f s = foldl1' f s
  {-# INLINE otoList #-}
  otoList s = toList s
  {-# INLINE oall #-}
  oall p s = all p s
  {-# INLINE oany #-}
  oany p s = any p s
  {-# INLINE onull #-}
  onull s = null s
  {-# INLINE olength #-}
  olength s = size s
  {-# INLINE olength64 #-}
  olength64 s = fromIntegral (size s)
  {-# INLINE headEx #-}
  headEx s = minimum s
  {-# INLINE lastEx #-}
  lastEx s = maximum s
  {-# INLINE oelem #-}
  oelem x s = member x s
  {-# INLINE onotElem #-}
  onotElem x s = not (member x s)
-- | Marker instance with no methods. '<>' on 'EnumSet' is 'union' (bitwise
-- OR), which never removes an element, so appending never shrinks a set.
instance (FiniteBits w, Num w, Enum a) => GrowingAppend (EnumSet w a)
-- | Effectful traversal, delegating to this module's 'traverse'.
instance (FiniteBits w, Num w, Enum a) => MonoTraversable (EnumSet w a) where
  otraverse f s = traverse f s
  {-# INLINE otraverse #-}
-- | 'SetContainer' view of an 'EnumSet', keyed by its element type.
--
-- Unqualified names on the right-hand sides are this module's top-level
-- functions. The class methods are only in scope qualified (via
-- @import qualified Data.Containers@), so these bindings do not recurse.
instance (FiniteBits w, Num w, Eq a, Enum a) => SetContainer (EnumSet w a) where
  type ContainerKey (EnumSet w a) = a

  {-# INLINE member #-}
  member x s = member x s
  {-# INLINE notMember #-}
  notMember x = notMember x
  {-# INLINE union #-}
  union s t = union s t
  {-# INLINE difference #-}
  difference s t = difference s t
  {-# INLINE intersection #-}
  intersection s t = intersection s t
  -- Keys come out in ascending order of 'fromEnum'.
  {-# INLINE keys #-}
  keys s = toList s
-- | 'IsSet' operations, each forwarding to the corresponding top-level
-- function of this module.
instance (FiniteBits w, Num w, Eq a, Enum a) => IsSet (EnumSet w a) where
  {-# INLINE insertSet #-}
  insertSet x s = insert x s
  {-# INLINE deleteSet #-}
  deleteSet x s = delete x s
  {-# INLINE singletonSet #-}
  singletonSet x = singleton x
  {-# INLINE setFromList #-}
  setFromList xs = fromFoldable xs
  {-# INLINE setToList #-}
  setToList s = toList s
  {-# INLINE filterSet #-}
  filterSet p s = filter p s
-- | Renders as @fromList [..]@ with the elements in ascending order of
-- 'fromEnum'. It is parenthesised when the surrounding precedence exceeds
-- function application's (10).
instance (FiniteBits w, Num w, Enum x, Show x) => Show (EnumSet w x) where
  showsPrec d s =
    showParen (d > 10) (showString "fromList " . shows (toList s))
  {-# INLINABLE showsPrec #-}
-- | Parses the @fromList [..]@ form produced by 'Show'. Optional surrounding
-- parentheses are accepted.
instance (Bits w, Num w, Enum x, Read x) => Read (EnumSet w x) where
  readPrec = parens (prec 10 body)
    where
      body = do
        Ident "fromList" <- lexP
        elems <- readPrec :: ReadPrec [x]
        pure (fromFoldable elems)
  {-# INLINABLE readPrec #-}

  readListPrec = readListPrecDefault
  {-# INLINABLE readListPrec #-}
-- | The set with no elements: every bit of the underlying word is clear.
empty :: ∀ w a. Bits w
      => EnumSet w a
empty = EnumSet (zeroBits :: w)
{-# INLINE empty #-}
-- | The set holding exactly one element: only bit @'fromEnum' x@ is set.
singleton :: ∀ w a. (Bits w, Enum a)
          => a -> EnumSet w a
singleton x = EnumSet (bit (fromEnum x))
{-# INLINE singleton #-}
-- | Build a set from any 'Foldable' container. A strict left fold ORs in the
-- bit of each element, so duplicates are harmless.
fromFoldable :: ∀ f w a. (Foldable f, Bits w, Enum a)
             => f a -> EnumSet w a
fromFoldable xs = EnumSet (F.foldl' addElem zeroBits xs)
  where
    addElem acc x = bit (fromEnum x) .|. acc
-- | Add an element to the set. Inserting an existing element is a no-op.
-- The element argument is forced.
insert :: ∀ w a. (Bits w, Enum a)
       => a -> EnumSet w a -> EnumSet w a
insert x (EnumSet w) = x `seq` EnumSet (setBit w (fromEnum x))
-- | Remove an element from the set. Deleting an absent element is a no-op.
-- The element argument is forced.
delete :: ∀ w a. (Bits w, Enum a)
       => a -> EnumSet w a -> EnumSet w a
delete x (EnumSet w) = x `seq` EnumSet (clearBit w (fromEnum x))
-- | Is the element in the set, i.e. is bit @'fromEnum' x@ set?
-- The element argument is forced.
member :: ∀ w a. (Bits w, Enum a)
       => a -> EnumSet w a -> Bool
member x (EnumSet w) = x `seq` testBit w (fromEnum x)
-- | Negation of 'member'. The element is forced as soon as 'notMember' is
-- applied to it.
notMember :: ∀ w a. (Bits w, Enum a)
          => a -> EnumSet w a -> Bool
notMember !x = \set -> not (member x set)
-- | Is the set empty, i.e. is every bit of the underlying word clear?
null :: ∀ w a. Bits w
     => EnumSet w a -> Bool
null s = case s of
  EnumSet w -> zeroBits == w
{-# INLINE null #-}
-- | Number of elements in the set: the population count of the word.
size :: ∀ w a. (Bits w, Num w)
     => EnumSet w a -> Int
size (EnumSet w) = popCount $! w
-- | @isSubsetOf s t@ holds when every element of @s@ is also in @t@, i.e.
-- masking @s@ with @t@ leaves @s@ unchanged.
isSubsetOf :: ∀ w a. (Bits w)
           => EnumSet w a -> EnumSet w a -> Bool
isSubsetOf (EnumSet sub) (EnumSet sup) = sub .&. sup == sub
{-# INLINE isSubsetOf #-}
-- | Elements present in either set (bitwise OR).
union :: ∀ w a. Bits w
      => EnumSet w a -> EnumSet w a -> EnumSet w a
union (EnumSet l) (EnumSet r) = EnumSet (l .|. r)
{-# INLINE union #-}
-- | Elements of the first set that are not in the second: the bits of the
-- left word that are clear in the right one.
difference :: ∀ w a. Bits w
           => EnumSet w a -> EnumSet w a -> EnumSet w a
difference (EnumSet l) (EnumSet r) = EnumSet (l .&. complement r)
{-# INLINE difference #-}
-- | Operator form of 'difference', left-associative at precedence 9.
(\\) :: ∀ w a. Bits w
     => EnumSet w a -> EnumSet w a -> EnumSet w a
a \\ b = difference a b
infixl 9 \\
{-# INLINE (\\) #-}
-- | Elements in exactly one of the two sets (bitwise XOR).
symmetricDifference :: ∀ w a. Bits w
                    => EnumSet w a -> EnumSet w a -> EnumSet w a
symmetricDifference (EnumSet l) (EnumSet r) = EnumSet (l `xor` r)
{-# INLINE symmetricDifference #-}
-- | Elements present in both sets (bitwise AND).
intersection :: ∀ w a. Bits w
             => EnumSet w a -> EnumSet w a -> EnumSet w a
intersection (EnumSet l) (EnumSet r) = EnumSet (l .&. r)
{-# INLINE intersection #-}
-- | Keep only the elements that satisfy the predicate. The result is rebuilt
-- bit by bit with a strict fold over the set bits of the input.
filter :: ∀ w a. (FiniteBits w, Num w, Enum a)
       => (a -> Bool) -> EnumSet w a -> EnumSet w a
filter keep (EnumSet w) = EnumSet (foldlBits' step 0 w)
  where
    step acc i = if keep (toEnum i) then acc `setBit` i else acc
    {-# INLINE step #-}
-- | Split a set in one pass into the elements that satisfy the predicate
-- (first component) and those that do not (second component).
partition :: ∀ w a. (FiniteBits w, Num w, Enum a)
          => (a -> Bool) -> EnumSet w a -> (EnumSet w a, EnumSet w a)
partition p (EnumSet w) = (EnumSet yay, EnumSet nay)
  where
    (yay, nay) = foldlBits' f (0, 0) w
    -- foldlBits' forces the accumulator only to WHNF, which for a pair is
    -- just the constructor. Banging both components keeps each half
    -- evaluated at every step instead of accumulating a chain of setBit
    -- thunks inside the tuple.
    f (!x, !y) i
      | p $ toEnum i = (setBit x i, y)
      | otherwise = (x, setBit y i)
    {-# INLINE f #-}
-- | Apply a function to every element, keeping the same word type. Elements
-- whose images share a 'fromEnum' merge, so the result may be smaller.
map :: ∀ w a b. (FiniteBits w, Num w, Enum a, Enum b)
    => (a -> b) -> EnumSet w a -> EnumSet w b
map f s = map' f s
{-# INLINE map #-}
-- | Like 'map', but the result may use a different word type. Each element's
-- image sets bit @'fromEnum' (f x)@ in a fresh word.
map' :: ∀ v w a b. (FiniteBits v, FiniteBits w, Num v, Num w, Enum a, Enum b)
     => (a -> b) -> EnumSet v a -> EnumSet w b
map' g (EnumSet w) = EnumSet (foldlBits' step 0 w)
  where
    step acc i = acc `setBit` fromEnum (g (toEnum i))
    {-# INLINE step #-}
-- | Lazy left fold over the elements in ascending order of 'fromEnum'.
foldl :: ∀ w a b. (FiniteBits w, Num w, Enum a)
      => (b -> a -> b) -> b -> EnumSet w a -> b
foldl f z (EnumSet w) = foldlBits (\acc i -> f acc (toEnum i)) z w
{-# INLINE foldl #-}
-- | Strict left fold over the elements in ascending order of 'fromEnum'. The
-- accumulator is forced at every step.
foldl' :: ∀ w a b. (FiniteBits w, Num w, Enum a)
       => (b -> a -> b) -> b -> EnumSet w a -> b
foldl' f z (EnumSet w) = foldlBits' (\acc i -> f acc (toEnum i)) z w
{-# INLINE foldl' #-}
-- | Lazy right fold over the elements in ascending order of 'fromEnum'.
foldr :: ∀ w a b. (FiniteBits w, Num w, Enum a)
      => (a -> b -> b) -> b -> EnumSet w a -> b
foldr f z (EnumSet w) = foldrBits (\i acc -> f (toEnum i) acc) z w
{-# INLINE foldr #-}
-- | Right fold that forces the result of each inner fold before passing it to
-- the combining function.
foldr' :: ∀ w a b. (FiniteBits w, Num w, Enum a)
       => (a -> b -> b) -> b -> EnumSet w a -> b
foldr' f z (EnumSet w) = foldrBits' (\i acc -> f (toEnum i) acc) z w
{-# INLINE foldr' #-}
-- | Lazy left fold seeded with the smallest element. Calls 'error' on an
-- empty set.
foldl1 :: ∀ w a. (FiniteBits w, Num w, Enum a)
       => (a -> a -> a) -> EnumSet w a -> a
foldl1 f = fold1Aux lsb (foldlBits (\acc i -> f acc (toEnum i)))
{-# INLINE foldl1 #-}
-- | Strict left fold seeded with the smallest element. Calls 'error' on an
-- empty set.
foldl1' :: ∀ w a. (FiniteBits w, Num w, Enum a)
        => (a -> a -> a) -> EnumSet w a -> a
foldl1' f = fold1Aux lsb (foldlBits' (\acc i -> f acc (toEnum i)))
{-# INLINE foldl1' #-}
-- | Lazy right fold seeded with the largest element. Calls 'error' on an
-- empty set.
foldr1 :: ∀ w a. (FiniteBits w, Num w, Enum a)
       => (a -> a -> a) -> EnumSet w a -> a
foldr1 f = fold1Aux msb (foldrBits (\i acc -> f (toEnum i) acc))
{-# INLINE foldr1 #-}
-- | Right fold seeded with the largest element that forces each inner result.
-- Calls 'error' on an empty set.
foldr1' :: ∀ w a. (FiniteBits w, Num w, Enum a)
        => (a -> a -> a) -> EnumSet w a -> a
foldr1' f = fold1Aux msb (foldrBits' (\i acc -> f (toEnum i) acc))
{-# INLINE foldr1' #-}
-- | Map every element into a monoid and combine the results in ascending
-- order of 'fromEnum'.
foldMap :: ∀ m w a. (Monoid m, FiniteBits w, Num w, Enum a)
        => (a -> m) -> EnumSet w a -> m
foldMap f (EnumSet w) = foldrBits step mempty w
  where
    step i rest = f (toEnum i) `mappend` rest
{-# INLINE foldMap #-}
-- | Run an effectful function on each element, in ascending order of
-- 'fromEnum', and collect the results into a new set. Results that share a
-- 'fromEnum' collapse into one element.
traverse :: ∀ f w a. (Applicative f, FiniteBits w, Num w, Enum a)
         => (a -> f a) -> EnumSet w a -> f (EnumSet w a)
traverse f (EnumSet w) = fmap EnumSet (foldrBits step (pure zeroBits) w)
  where
    step i acc =
      liftA2 (\j bits -> setBit bits j) (fromEnum <$> f (toEnum i)) acc
{-# INLINE traverse #-}
-- | Does every element of the set satisfy the predicate?
--
-- Elements are tested in ascending order of 'fromEnum', stopping at the first
-- failure. Each step jumps to the lowest remaining set bit with
-- 'countTrailingZeros' and clears it. The loop therefore runs once per element
-- and terminates for every word type. A right-shift walk would not terminate
-- for a signed word with its sign bit set, because the shift sign-extends.
all :: ∀ w a. (FiniteBits w, Num w, Enum a)
    => (a -> Bool) -> EnumSet w a -> Bool
all p (EnumSet w) = go w
  where
    go 0 = True
    go n = p (toEnum i) && go (clearBit n i)
      where
        i = countTrailingZeros n
-- | Does at least one element of the set satisfy the predicate?
--
-- Elements are tested in ascending order of 'fromEnum', stopping at the first
-- success. Each step jumps to the lowest remaining set bit with
-- 'countTrailingZeros' and clears it. The loop therefore runs once per element
-- and terminates for every word type. A right-shift walk would not terminate
-- for a signed word with its sign bit set, because the shift sign-extends.
any :: ∀ w a. (FiniteBits w, Num w, Enum a)
    => (a -> Bool) -> EnumSet w a -> Bool
any p (EnumSet w) = go w
  where
    go 0 = False
    go n = p (toEnum i) || go (clearBit n i)
      where
        i = countTrailingZeros n
-- | The element with the smallest 'fromEnum', i.e. the lowest set bit.
-- Calls 'error' on an empty set.
minimum :: ∀ w a. (FiniteBits w, Num w, Enum a)
        => EnumSet w a -> a
minimum (EnumSet w)
  | w == 0 = error "empty EnumSet"
  | otherwise = toEnum (lsb w)
-- | The element with the largest 'fromEnum', i.e. the highest set bit.
-- Calls 'error' on an empty set.
maximum :: ∀ w a. (FiniteBits w, Num w, Enum a)
        => EnumSet w a -> a
maximum (EnumSet w)
  | w == 0 = error "empty EnumSet"
  | otherwise = toEnum (msb w)
-- | Remove the smallest element (lowest set bit). An empty set is returned
-- unchanged.
deleteMin :: ∀ w a. (FiniteBits w, Num w)
          => EnumSet w a -> EnumSet w a
deleteMin (EnumSet w)
  | w == 0 = EnumSet 0
  | otherwise = EnumSet (clearBit w (lsb w))
-- | Remove the largest element (highest set bit). An empty set is returned
-- unchanged.
deleteMax :: ∀ w a. (FiniteBits w, Num w)
          => EnumSet w a -> EnumSet w a
deleteMax (EnumSet w)
  | w == 0 = EnumSet 0
  | otherwise = EnumSet (clearBit w (msb w))
-- | Split off the smallest element, or 'Nothing' for an empty set.
minView :: ∀ w a. (FiniteBits w, Num w, Enum a)
        => EnumSet w a -> Maybe (a, EnumSet w a)
minView (EnumSet w)
  | w == 0 = Nothing
  | otherwise = Just (toEnum lo, EnumSet (clearBit w lo))
  where
    lo = lsb w
-- | Split off the largest element, or 'Nothing' for an empty set.
maxView :: ∀ w a. (FiniteBits w, Num w, Enum a)
        => EnumSet w a -> Maybe (a, EnumSet w a)
maxView (EnumSet w)
  | w == 0 = Nothing
  | otherwise = Just (toEnum hi, EnumSet (clearBit w hi))
  where
    hi = msb w
-- | The elements of the set in ascending order of 'fromEnum'.
--
-- The list is produced with 'build', so list consumers can fuse with it.
toList :: ∀ w a. (FiniteBits w, Num w, Enum a)
       => EnumSet w a -> [a]
toList (EnumSet w) =
  build (\cons nil -> foldrBits (\i rest -> cons (toEnum i) rest) nil w)
{-# INLINE toList #-}
-- | Reinterpret a raw word as a set: bit @i@ set means @'toEnum' i@ is a
-- member. No validation is performed.
fromRaw :: ∀ w a. w -> EnumSet w a
fromRaw w = EnumSet w
{-# INLINE fromRaw #-}
-- | Index of the lowest set bit of a word.
--
-- Delegates to 'countTrailingZeros', which is correct for every
-- 'finiteBitSize'. A halving binary search is only exact when the width is a
-- power of two. For the machine word types it typically maps to a single
-- hardware instruction.
--
-- Zero has no set bit. For zero we return @finiteBitSize n - 1@ rather than
-- the @finiteBitSize n@ that 'countTrailingZeros' gives, so a caller that
-- shifts by the result stays below the bit width, where 'unsafeShiftR' is
-- defined.
lsb :: ∀ w. (FiniteBits w, Num w) => w -> Int
lsb n
  | n == 0 = finiteBitSize n - 1
  | otherwise = countTrailingZeros n
{-# INLINE lsb #-}
-- | Index of the highest set bit of a word.
--
-- Computed from 'countLeadingZeros', which is correct for every
-- 'finiteBitSize'. A halving binary search is only exact when the width is a
-- power of two.
--
-- Zero has no set bit and yields 0. The callers in this module ('maximum',
-- 'deleteMax', 'maxView', 'fold1Aux') reject an empty set before asking for
-- the highest bit.
msb :: ∀ w. (FiniteBits w, Num w) => w -> Int
msb n
  | n == 0 = 0
  | otherwise = finiteBitSize n - 1 - countLeadingZeros n
{-# INLINE msb #-}
-- | Lazy left fold over the indices of the set bits of a word, visiting them
-- in ascending order. The accumulator is not forced.
--
-- Each step finds the lowest remaining set bit with 'countTrailingZeros' and
-- clears it. The loop therefore runs once per set bit and terminates for
-- every word type. A right-shift walk that waits for the word to reach zero
-- never terminates for a signed type whose sign bit is set, because the shift
-- sign-extends.
foldlBits :: ∀ w a. (FiniteBits w, Num w) => (a -> Int -> a) -> a -> w -> a
foldlBits f z w0 = go z w0
  where
    go acc 0 = acc
    go acc n = go (f acc i) (clearBit n i)
      where
        i = countTrailingZeros n
{-# INLINE foldlBits #-}
-- | Strict left fold over the indices of the set bits of a word, visiting
-- them in ascending order. The accumulator, including the initial value, is
-- forced to WHNF at every step.
--
-- Each step finds the lowest remaining set bit with 'countTrailingZeros' and
-- clears it. The loop therefore runs once per set bit and terminates for
-- every word type. A right-shift walk that waits for the word to reach zero
-- never terminates for a signed type whose sign bit is set, because the shift
-- sign-extends.
foldlBits' :: ∀ w a. (FiniteBits w, Num w) => (a -> Int -> a) -> a -> w -> a
foldlBits' f z w0 = go z w0
  where
    go !acc 0 = acc
    go !acc n = go (f acc i) (clearBit n i)
      where
        i = countTrailingZeros n
{-# INLINE foldlBits' #-}
-- | Lazy right fold over the indices of the set bits of a word, in ascending
-- order. Because the recursion sits under @f@, a lazy @f@ such as @(:)@ lets
-- the result be consumed incrementally.
--
-- Each step finds the lowest remaining set bit with 'countTrailingZeros' and
-- clears it. The loop therefore runs once per set bit and terminates for
-- every word type. A right-shift walk that waits for the word to reach zero
-- never terminates for a signed type whose sign bit is set, because the shift
-- sign-extends.
foldrBits :: ∀ w a. (FiniteBits w, Num w) => (Int -> a -> a) -> a -> w -> a
foldrBits f z w0 = go w0
  where
    go 0 = z
    go n = f i (go (clearBit n i))
      where
        i = countTrailingZeros n
{-# INLINE foldrBits #-}
-- | Right fold over the indices of the set bits of a word, in ascending
-- order. The result of each inner fold is forced to WHNF before it is passed
-- to @f@.
--
-- Each step finds the lowest remaining set bit with 'countTrailingZeros' and
-- clears it. The loop therefore runs once per set bit and terminates for
-- every word type. A right-shift walk that waits for the word to reach zero
-- never terminates for a signed type whose sign bit is set, because the shift
-- sign-extends.
foldrBits' :: ∀ w a. (FiniteBits w, Num w) => (Int -> a -> a) -> a -> w -> a
foldrBits' f z w0 = go w0
  where
    go 0 = z
    go n = f i $! go (clearBit n i)
      where
        i = countTrailingZeros n
{-# INLINE foldrBits' #-}
-- | Shared driver for the @fold*1@ functions. It picks one set bit with
-- @getBit@ (callers pass 'lsb' or 'msb') and converts it with 'toEnum' to
-- seed the fold @f@. It then runs @f@ over the word with that bit cleared.
-- Calls 'error' on an empty set.
fold1Aux :: ∀ w a. (Bits w, Num w, Enum a)
         => (w -> Int) -> (a -> w -> a) -> EnumSet w a -> a
fold1Aux getBit f (EnumSet w)
  | w == 0 = error "empty EnumSet"
  | otherwise = f (toEnum seedBit) (clearBit w seedBit)
  where
    seedBit = getBit w
{-# INLINE fold1Aux #-}