{-# LANGUAGE CPP #-}
{-# LANGUAGE Rank2Types #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE UndecidableInstances #-}
module Data.Bits.Coding
( Coding(..)
, getAligned, getBit, getBits, getBitsFrom
, putAligned, putUnaligned, putBit, putBits, putBitsFrom
) where
import Control.Applicative
import Control.Monad
import qualified Control.Monad.Fail as Fail
import Control.Monad.State.Class
import Control.Monad.Reader.Class
import Control.Monad.Trans
import Data.Bits
import Data.Bits.Extras
import Data.Bytes.Get
import Data.Bytes.Put
import Data.Word
{-# ANN module "hlint: ignore Redundant $!" #-}
-- | A bit-level coding layer over an underlying monad @m@, written in
-- continuation-passing style.
--
-- Besides its result, every step hands the continuation an 'Int' and a
-- 'Word8' describing the partially processed byte. 'getBit' reads the 'Int'
-- as the number of unread bits still held in the 'Word8', whereas 'putBit'
-- reads it as the number of bits already stored in the 'Word8'.
newtype Coding m a = Coding
  { runCoding
      :: forall r
       . (a -> Int -> Word8 -> m r) -- continuation: result, bit count, bit buffer
      -> Int                         -- incoming bit count
      -> Word8                       -- incoming bit buffer
      -> m r
  }
-- | Mapping applies the function to the result just before it reaches the
-- continuation; the bit count and bit buffer pass through unchanged.
instance Functor (Coding m) where
  fmap g (Coding run) = Coding (\k -> run (\a -> k (g a)))
  {-# INLINE fmap #-}
-- | 'pure' hands its value straight to the continuation without touching the
-- bit state; '<*>' sequences left to right through the 'Monad' instance.
instance Monad m => Applicative (Coding m) where
  pure x = Coding ($ x)
  {-# INLINE pure #-}
  -- Spelled out with explicit binds: run the function coder, then the
  -- argument coder, then apply.
  cf <*> cx = cf >>= \f -> cx >>= \x -> return (f x)
  {-# INLINE (<*>) #-}
-- | Bind runs the first coder with a continuation that builds the second
-- coder from its result and resumes it with the original continuation, so
-- the bit count and bit buffer flow from one step into the next.
instance Monad m => Monad (Coding m) where
  return = pure
  {-# INLINE return #-}
  c >>= f = Coding (\k -> runCoding c (\a -> runCoding (f a) k))
  {-# INLINE (>>=) #-}
#if !(MIN_VERSION_base(4,13,0))
  -- On base < 4.13 'fail' is still a 'Monad' method: drop the continuation
  -- and bit state and fail in the underlying monad.
  fail msg = Coding (\_ _ _ -> fail msg)
  {-# INLINE fail #-}
#endif
-- | Failure discards the continuation and the bit state and defers to
-- 'Fail.fail' of the underlying monad.
instance Fail.MonadFail m => Fail.MonadFail (Coding m) where
  fail msg = Coding (\_ _ _ -> Fail.fail msg)
  {-# INLINE fail #-}
-- | Choice is delegated to the underlying monad's 'Alternative' instance.
instance (Monad m, Alternative m) => Alternative (Coding m) where
  empty = Coding (\_ _ _ -> empty)
  {-# INLINE empty #-}
  -- Both branches start from the same bit state and are run only far enough
  -- to capture their result and final bit state; the underlying '<|>' picks
  -- between them, and the continuation is resumed afterwards, outside of the
  -- choice.
  Coding l <|> Coding r = Coding $ \k i b ->
    (l capture i b <|> r capture i b) >>= \(x, j, w) -> k x j w
    where
      capture x j w = pure (x, j, w)
  {-# INLINE (<|>) #-}
-- | 'mzero' and 'mplus' are delegated to the underlying 'MonadPlus' instance.
instance MonadPlus m => MonadPlus (Coding m) where
  mzero = Coding (\_ _ _ -> mzero)
  {-# INLINE mzero #-}
  -- Each alternative is run from the same starting bit state up to the point
  -- where its result and final bit state are known; the continuation is
  -- applied to whichever outcome the underlying 'mplus' yields.
  mplus (Coding l) (Coding r) = Coding $ \k i b ->
    (l capture i b `mplus` r capture i b) >>= \(x, j, w) -> k x j w
    where
      capture x j w = return (x, j, w)
  {-# INLINE mplus #-}
-- | Lifting runs the underlying action and passes its result on together
-- with the unchanged bit count and bit buffer.
instance MonadTrans Coding where
  lift action = Coding (\k i w -> action >>= \a -> k a i w)
  {-# INLINE lift #-}
-- | State operations are lifted from the underlying monad; they leave the
-- bit count and bit buffer as they were.
instance MonadState s m => MonadState s (Coding m) where
  get = lift get
  {-# INLINE get #-}
  put s = lift (put s)
  {-# INLINE put #-}
-- | Reader operations are lifted from the underlying monad.
instance MonadReader e m => MonadReader e (Coding m) where
  ask = lift ask
  {-# INLINE ask #-}
  -- Only the wrapped coder runs under the modified environment: it is run
  -- far enough to capture its result and final bit state, and the
  -- continuation is resumed outside of 'local'.
  local f (Coding run) = Coding $ \k i b ->
    local f (run capture i b) >>= \(x, j, w) -> k x j w
    where
      capture x j w = return (x, j, w)
  {-# INLINE local #-}
-- | Run a byte-level action of the underlying monad and continue with the
-- bit count and bit buffer both reset to zero. The bit state that was in
-- effect before the call is ignored, so any buffered bits are dropped.
getAligned :: MonadGet m => m a -> Coding m a
getAligned m = Coding $ \k _ _ -> do
  a <- m
  k a 0 0
{-# INLINE getAligned #-}
-- | Read a single bit.
--
-- When the bit count is zero a fresh byte is fetched with 'getWord8' and 7
-- bits remain afterwards; otherwise the bit is taken from the current buffer
-- and the count is decremented. The bit returned is always bit 7 of the
-- buffer, and the buffer is shifted left by one so the next bit moves into
-- that position.
getBit :: MonadGet m => Coding m Bool
getBit = Coding $ \k i b ->
  let -- Hand bit 7 of @w@ to the continuation with @left@ bits remaining;
      -- all three arguments are forced before the continuation runs.
      emit left w = ((k $! testBit w 7) $! left) $! unsafeShiftL w 1
  in if i == 0
       then getWord8 >>= emit 7
       else emit (i - 1) b
{-# INLINE getBit #-}
-- | Read bits one at a time with 'getBit', for positions @from@ down to @to@
-- (inclusive), passing each bit and its position to 'assignBit' on the
-- accumulator @bits@. Returns the accumulator unchanged, reading nothing,
-- when @from < to@.
getBits :: (MonadGet m, Bits b) => Int -> Int -> b -> Coding m b
getBits from to bits
  | from >= to = getBit >>= \b -> getBits (pred from) to (assignBit bits from b)
  | otherwise = return bits
{-# INLINE getBits #-}
-- | 'getBits' with the lower position fixed at 0: reads positions @from@
-- down to 0 into the given accumulator.
getBitsFrom :: (MonadGet m, Bits b) => Int -> b -> Coding m b
getBitsFrom from bits = getBits from 0 bits
{-# INLINE getBitsFrom #-}
-- | Byte-level reads through 'Coding'.
--
-- Every method that consumes bytes goes through 'getAligned', which drops
-- any buffered bits and resets the bit state. 'remaining' and 'isEmpty' are
-- lifted as-is and leave the bit state alone, so they report on the
-- underlying byte stream only.
instance MonadGet m => MonadGet (Coding m) where
  type Remaining (Coding m) = Remaining m
  type Bytes (Coding m) = Bytes m
  skip = getAligned . skip
  {-# INLINE skip #-}
  -- Only the wrapped coder runs inside the underlying 'lookAhead'; the
  -- continuation is resumed afterwards, outside of it. Because the byte
  -- input is rewound, the continuation also gets back the bit state from
  -- before the look-ahead, keeping the bit buffer consistent with the byte
  -- position.
  lookAhead (Coding m) = Coding $ \k i b ->
    lookAhead (m (\a _ _ -> return a) i b) >>= \a -> k a i b
  {-# INLINE lookAhead #-}
  -- The continuation is returned as a value tagged 'Left' (rewind) or
  -- 'Right' (keep), so the underlying 'lookAheadE' decides whether input is
  -- consumed and the continuation itself runs outside of it. On 'Nothing'
  -- the input is rewound, so the original bit state is restored; on 'Just'
  -- the bit state left behind by the coder is kept.
  lookAheadM (Coding m) = Coding $ \k i b ->
    let distribute Nothing _ _ = return $ Left $ k Nothing i b
        distribute (Just a) i' b' = return $ Right $ k (Just a) i' b'
    in lookAheadE (m distribute i b) >>= either id id
  {-# INLINE lookAheadM #-}
  -- Same scheme as 'lookAheadM': 'Left' rewinds and restores the original
  -- bit state, 'Right' keeps the consumed input and the new bit state.
  lookAheadE (Coding m) = Coding $ \k i b ->
    let distribute (Left e) _ _ = return $ Left $ k (Left e) i b
        distribute (Right a) i' b' = return $ Right $ k (Right a) i' b'
    in lookAheadE (m distribute i b) >>= either id id
  {-# INLINE lookAheadE #-}
  getBytes = getAligned . getBytes
  {-# INLINE getBytes #-}
  remaining = lift remaining
  {-# INLINE remaining #-}
  isEmpty = lift isEmpty
  {-# INLINE isEmpty #-}
  getWord8 = getAligned getWord8
  {-# INLINE getWord8 #-}
  getByteString = getAligned . getByteString
  {-# INLINE getByteString #-}
  getLazyByteString = getAligned . getLazyByteString
  {-# INLINE getLazyByteString #-}
  getWord16le = getAligned getWord16le
  {-# INLINE getWord16le #-}
  getWord32le = getAligned getWord32le
  {-# INLINE getWord32le #-}
  getWord64le = getAligned getWord64le
  {-# INLINE getWord64le #-}
  getWord16be = getAligned getWord16be
  {-# INLINE getWord16be #-}
  getWord32be = getAligned getWord32be
  {-# INLINE getWord32be #-}
  getWord64be = getAligned getWord64be
  {-# INLINE getWord64be #-}
  getWord16host = getAligned getWord16host
  {-# INLINE getWord16host #-}
  getWord32host = getAligned getWord32host
  {-# INLINE getWord32host #-}
  getWord64host = getAligned getWord64host
  {-# INLINE getWord64host #-}
  getWordhost = getAligned getWordhost
  {-# INLINE getWordhost #-}
-- | Run a byte-level action of the underlying monad on a byte boundary.
--
-- If the bit count is non-zero, the current bit buffer is first written out
-- as one byte with 'putWord8'. The action then runs, and the continuation
-- resumes with the bit count and bit buffer both reset to zero.
putAligned :: MonadPut m => m a -> Coding m a
putAligned m = Coding $ \k i b -> do
  when (i /= 0) $ putWord8 b
  a <- m
  k a 0 0
-- | Write every bit of a fixed-width value through 'putBitsFrom', starting
-- at its highest position (@finiteBitSize b - 1@) and going down to 0.
putUnaligned :: (MonadPut m, FiniteBits b) => b -> Coding m ()
putUnaligned b = putBitsFrom top b
  where
    top = pred (finiteBitSize b)
{-# INLINE putUnaligned #-}
-- | Write a single bit.
--
-- The bit is stored at position @7 - i@ of the buffer, where @i@ is the
-- current bit count (set for 'True', cleared for 'False'). When this was the
-- eighth bit (@i == 7@) the completed byte is written with 'putWord8' and
-- the bit state is reset to zero; otherwise the count is incremented and the
-- updated buffer is carried forward.
putBit :: MonadPut m => Bool -> Coding m ()
putBit v = Coding $ \k i b ->
  let b' = (if v then setBit else clearBit) b (7 - i)
  in if i == 7
       then putWord8 b' >> k () 0 0
       else (k () $! i + 1) $! b'
{-# INLINE putBit #-}
-- | Write the bits of @b@ at positions @from@ down to @to@ (inclusive), one
-- 'putBit' per position, highest position first. Does nothing when
-- @from < to@.
putBits :: (MonadPut m, Bits b) => Int -> Int -> b -> Coding m ()
putBits from to b =
  unless (from < to) $ do
    putBit (testBit b from)
    putBits (pred from) to b
{-# INLINE putBits #-}
-- | 'putBits' with the lower position fixed at 0: writes positions @from@
-- down to 0 of the given value.
putBitsFrom :: (MonadPut m, Bits b) => Int -> b -> Coding m ()
putBitsFrom from b = putBits from 0 b
{-# INLINE putBitsFrom #-}
-- | Byte-level writes through 'Coding'.
--
-- Every method is the corresponding method of the underlying monad wrapped
-- in 'putAligned', so a partially filled bit buffer is written out before
-- the byte-level write happens. This includes 'flush'.
instance MonadPut m => MonadPut (Coding m) where
  putWord8 w = putAligned (putWord8 w)
  {-# INLINE putWord8 #-}
  putByteString bs = putAligned (putByteString bs)
  {-# INLINE putByteString #-}
  putLazyByteString lbs = putAligned (putLazyByteString lbs)
  {-# INLINE putLazyByteString #-}
  flush = putAligned flush
  {-# INLINE flush #-}
  putWord16le w = putAligned (putWord16le w)
  {-# INLINE putWord16le #-}
  putWord32le w = putAligned (putWord32le w)
  {-# INLINE putWord32le #-}
  putWord64le w = putAligned (putWord64le w)
  {-# INLINE putWord64le #-}
  putWord16be w = putAligned (putWord16be w)
  {-# INLINE putWord16be #-}
  putWord32be w = putAligned (putWord32be w)
  {-# INLINE putWord32be #-}
  putWord64be w = putAligned (putWord64be w)
  {-# INLINE putWord64be #-}
  putWord16host w = putAligned (putWord16host w)
  {-# INLINE putWord16host #-}
  putWord32host w = putAligned (putWord32host w)
  {-# INLINE putWord32host #-}
  putWord64host w = putAligned (putWord64host w)
  {-# INLINE putWord64host #-}
  putWordhost w = putAligned (putWordhost w)
  {-# INLINE putWordhost #-}