{-# LANGUAGE TemplateHaskellQuotes #-}
module Sasha.Internal.Word8Set (
    -- * Set type
    Word8Set,
    Key,

    -- * Construction
    empty,
    full,
    singleton,
    range,
    fromList,

    -- * Insertion
    insert,

    -- * Deletion
    delete,

    -- * Query
    member,
    memberCode,
    isSubsetOf,
    null,
    isFull,
    isSingleRange,
    size,

    -- * Combine
    union,
    intersection,
    complement,

    -- * Min\/Max
    findMin,
    findMax,
    -- * Conversion to List
    elems,
    toList,
) where

import Prelude
       (Bool (..), Eq ((==)), Int, Monoid (..), Ord, Semigroup (..),
       Show (showsPrec), fromIntegral, negate, otherwise, showParen, showString,
       ($), (&&), (+), (-), (.), (<), (<=), (>), (||), return)

import Data.Bits             ((.&.), (.|.))
import Data.Foldable         (foldl')
import Data.WideWord.Word256 (Word256 (..))
import Data.Word             (Word64, Word8)
import Test.QuickCheck       (Arbitrary (..))
import Algebra.Lattice
       (BoundedJoinSemiLattice (..), BoundedMeetSemiLattice (..), Lattice (..))

import Language.Haskell.TH.Syntax

import qualified Data.Bits as Bits

-------------------------------------------------------------------------------
-- Types
-------------------------------------------------------------------------------

-- | A set of 'Word8' values stored as a 256-bit bitmap: the value @k@ is a
-- member exactly when bit @k@ of the wrapped 'Word256' is set.
--
-- The derived 'Ord' compares the underlying bitmaps. It is a total order,
-- not the subset order (use 'isSubsetOf' for that).
newtype Word8Set = W8S Word256
  deriving (Ord, Eq)

-- | Element type of a 'Word8Set'.
type Key = Word8

-------------------------------------------------------------------------------
-- Instances
-------------------------------------------------------------------------------

-- | Renders a set as @fromList [...]@, listing members in ascending order
-- (the order produced by 'toList'). The output is parenthesised when the
-- surrounding precedence is above 10, i.e. in argument position.
instance Show Word8Set where
    showsPrec prec ws =
        showParen (prec > 10) (showString "fromList " . showsPrec 11 (toList ws))

-- | Lifts a set into typed Template Haskell code by rebuilding the bitmap
-- from its four 'Word64' limbs, in the same field order as the 'Word256'
-- constructor.
instance Lift Word8Set where
    liftTyped (W8S (Word256 hi m1 m0 lo)) =
        [|| W8S (Word256 hi m1 m0 lo) ||]

-- | Sets combine under '<>' by 'union'.
instance Semigroup Word8Set where
    xs <> ys = union xs ys

-- | The identity of '<>' is the 'empty' set.
instance Monoid Word8Set where
    mempty = empty

-- | Generates a set by drawing each of the four 'Word64' limbs of the
-- underlying bitmap independently with 'arbitrary'.
--
-- Shrinking goes through the member list: every candidate is 'fromList'
-- applied to a shrink of @'toList' ws@ under the list instance of
-- 'shrink'. Failing properties therefore report small counterexamples
-- instead of the unshrunk generated bitmap (the default 'shrink' returns
-- no candidates).
instance Arbitrary Word8Set where
    arbitrary = do
        hi <- arbitrary
        m1 <- arbitrary
        m0 <- arbitrary
        lo <- arbitrary
        return (W8S (Word256 hi m1 m0 lo))

    shrink ws = [ fromList ks | ks <- shrink (toList ws) ]

-- | Join is 'union' and meet is 'intersection', i.e. the subset lattice.
instance Lattice Word8Set where
    a \/ b = union a b
    a /\ b = intersection a b

-- | The least element of the subset lattice is the 'empty' set.
instance BoundedJoinSemiLattice Word8Set where
    bottom = empty

-- | The greatest element of the subset lattice is the 'full' set.
instance BoundedMeetSemiLattice Word8Set where
    top = full

-------------------------------------------------------------------------------
-- Construction
-------------------------------------------------------------------------------

-- | The set with no members: every bit of the bitmap is clear.
empty :: Word8Set
empty = W8S 0

-- | The set containing all 256 'Word8' values: the complement of 'empty'.
full :: Word8Set
full = complement empty

-- | A 'Word256' with every bit set.
ones :: Word256
ones = Bits.complement 0

-- | The set whose only member is the given value: a bitmap with just that
-- bit set.
singleton :: Word8 -> Word8Set
singleton = W8S . Bits.bit . fromIntegral

-- | @'range' lo hi@ is the set of every value @k@ with @lo <= k <= hi@.
-- It is 'empty' when @hi < lo@.
range :: Word8 -> Word8 -> Word8Set
range lo hi
    | hi < lo   = empty
    | otherwise = W8S (Bits.shiftL window (fromIntegral lo))
  where
    -- Number of members, between 1 and 256. It is computed in 'Int' so the
    -- full range (256) does not wrap around to 0.
    width :: Int
    width = fromIntegral hi - fromIntegral lo + 1

    -- A mask with exactly the lowest @width@ bits set.
    window :: Word256
    window = Bits.shiftR ones (256 - width)

-------------------------------------------------------------------------------
-- Insertion
-------------------------------------------------------------------------------

-- | Adds a value to the set by setting its bit. Inserting an existing
-- member leaves the set unchanged.
insert :: Word8 -> Word8Set -> Word8Set
insert x (W8S bits) = W8S (bits .|. Bits.bit (fromIntegral x))

-------------------------------------------------------------------------------
-- Deletion
-------------------------------------------------------------------------------

-- | Removes a value from the set by clearing its bit. Deleting a
-- non-member leaves the set unchanged.
delete :: Word8 -> Word8Set -> Word8Set
delete x (W8S bits) = W8S (bits .&. Bits.complement (Bits.bit (fromIntegral x)))

-------------------------------------------------------------------------------
-- Query
-------------------------------------------------------------------------------

-- | Whether the set has no members.
null :: Word8Set -> Bool
null ws = ws == empty

-- | Whether the set contains all 256 'Word8' values.
isFull :: Word8Set -> Bool
isFull ws = ws == full

-- | Number of members, i.e. the population count of the bitmap.
size :: Word8Set -> Int
size ws = case ws of
    W8S bits -> Bits.popCount bits

-- | Whether the value is in the set: tests the value's bit in the bitmap.
member :: Word8 -> Word8Set -> Bool
member x (W8S bits) = bits `Bits.testBit` (fromIntegral x)

-- | Optimized routine to check membership when the 'Word8Set' is statically
-- known.
--
-- Semantically
--
-- @
-- 'memberCode' c ws = [||'member' $$c $$(liftTyped ws) ||]
-- @
--
-- but cheaper code is generated for common shapes of @ws@: empty, full,
-- one or two members, a single contiguous range, or members all below 64.
-- Where the tested value is needed more than once, it is let-bound so the
-- code of @c@ is spliced (and evaluated) only once.
memberCode :: Code Q Word8 -> Word8Set -> Code Q Bool
memberCode c ws
    -- simple cases
    | null ws      = [|| False ||]
    | isFull ws    = [|| True ||]
    | size ws == 1 =
        [|| $$c == $$(liftTyped (findMin ws)) ||]
    | size ws == 2 =
        [|| let k = $$c
            in k == $$(liftTyped (findMin ws)) || k == $$(liftTyped (findMax ws)) ||]

    -- continuous range
    | isSingleRange ws =
        [|| let k = $$c
            in $$(liftTyped (findMin ws)) <= k && k <= $$(liftTyped (findMax ws)) ||]

    -- every member is below 64: only the lowest Word64 limb is non-zero,
    -- so test the value directly against that limb
    | W8S (Word256 0 0 0 lo) <- ws =
        [|| let k = $$c :: Word8
            in k < 64 && Bits.testBit ($$(liftTyped lo) :: Word64) (fromIntegral k) ||]

    -- fallback
    | otherwise = [|| member $$c $$(liftTyped ws) ||]

-- | @'isSubsetOf' a b@ holds when every member of @a@ is also a member of
-- @b@, i.e. intersecting with @b@ loses nothing from @a@.
isSubsetOf :: Word8Set -> Word8Set -> Bool
isSubsetOf a b = intersection a b == a

-- | Whether the members form one contiguous run of values. The empty set
-- counts as a (degenerate) single range.
isSingleRange :: Word8Set -> Bool
isSingleRange (W8S bits)
    | bits == 0 = True
    | otherwise = Bits.popCount bits == spanned
  where
    -- Width of the window from the lowest to the highest set bit; the run is
    -- contiguous exactly when every bit in that window is set.
    spanned = 256 - Bits.countLeadingZeros bits - Bits.countTrailingZeros bits

-------------------------------------------------------------------------------
-- Combine
-------------------------------------------------------------------------------

-- | The set of all 'Word8' values not in the given set: flips every bit.
complement :: Word8Set -> Word8Set
complement (W8S bits) = W8S (Bits.xor bits ones)

-- | Values in either set: bitwise OR of the two bitmaps.
union :: Word8Set -> Word8Set -> Word8Set
union (W8S lhs) (W8S rhs) =
    W8S (lhs .|. rhs)

-- | Values in both sets: bitwise AND of the two bitmaps.
intersection :: Word8Set -> Word8Set -> Word8Set
intersection (W8S lhs) (W8S rhs) =
    W8S (lhs .&. rhs)

-------------------------------------------------------------------------------
-- Min/Max
-------------------------------------------------------------------------------

-- | The smallest member: the index of the lowest set bit.
--
-- This is not guarded against the empty set. There the trailing-zero count
-- is 256, which wraps to 0 when converted to 'Word8', so
-- @findMin empty == 0@.
findMin :: Word8Set -> Word8
findMin (W8S bits) = fromIntegral lowest
  where
    lowest = Bits.countTrailingZeros bits

-- | The largest member: the index of the highest set bit.
--
-- This is not guarded against the empty set. There the leading-zero count
-- is 256, giving @255 - 256 = -1@, which wraps to 255 as a 'Word8', so
-- @findMax empty == 255@.
findMax :: Word8Set -> Word8
findMax (W8S bits) = fromIntegral highest
  where
    highest :: Int
    highest = 255 - Bits.countLeadingZeros bits

-------------------------------------------------------------------------------
-- List
-------------------------------------------------------------------------------

-- | Members in ascending order; a synonym for 'toList'.
elems :: Word8Set -> [Word8]
elems ws = toList ws

-- | Members in ascending order. The list is produced lazily by testing each
-- value from 0x00 to 0xff with 'member'.
toList :: Word8Set -> [Word8]
toList ws = go 0
  where
    go :: Int -> [Word8]
    go i
        | i > 0xff    = []
        | member k ws = k : go (i + 1)
        | otherwise   = go (i + 1)
      where
        k = fromIntegral i

-- | Builds a set from a list of values. Duplicates are harmless. The list is
-- consumed with a strict left fold, inserting each value into 'empty'.
fromList :: [Word8] -> Word8Set
fromList = foldl' (\acc k -> insert k acc) empty