module Rattletrap.BitGet where

import qualified Control.Exception as Exception
import qualified Control.Monad as Monad
import qualified Data.Bits as Bits
import qualified Data.ByteString as ByteString
import qualified Data.Functor.Identity as Identity
import qualified Rattletrap.BitString as BitString
import qualified Rattletrap.ByteGet as ByteGet
import qualified Rattletrap.Exception.NotEnoughInput as NotEnoughInput
import qualified Rattletrap.Get as Get

-- | A 'Get.Get' parser whose input state is a 'BitString.BitString' and whose
-- underlying monad is 'Identity.Identity'.
type BitGet = Get.Get BitString.BitString Identity.Identity

-- | Runs a bit-level getter inside a byte-level getter.
--
-- The current byte-level input is wrapped with 'BitString.fromByteString'
-- and the bit getter is run over it with 'Get.run'. The outcome is then
-- translated back into the byte-level getter:
--
-- * On failure, the exception is re-thrown with 'ByteGet.throw', wrapped in
--   the labels reported alongside it (via 'Get.labels').
-- * On success, the byte-level state is replaced with
--   'BitString.byteString' of the leftover bit string and the parsed value
--   is returned.
--
-- NOTE(review): presumably 'BitString.byteString' yields the bytes backing
-- the remaining input, so a partially consumed byte's bit offset is not
-- carried over to the byte-level state -- confirm against
-- "Rattletrap.BitString".
toByteGet :: BitGet a -> ByteGet.ByteGet a
toByteGet g = do
  s1 <- Get.get
  case Identity.runIdentity . Get.run g $ BitString.fromByteString s1 of
    Left (ls, e) -> Get.labels ls $ ByteGet.throw e
    Right (s2, x) -> do
      Get.put $ BitString.byteString s2
      pure x

-- | Runs a byte-level getter inside a bit-level getter.
--
-- Exactly @n@ bytes are first read from the bit stream with 'byteString';
-- the byte getter is then run over just those bytes via 'Get.embed'.
--
-- NOTE(review): how 'Get.embed' treats input the byte getter leaves
-- unconsumed is not visible from this module -- confirm against
-- "Rattletrap.Get".
fromByteGet :: ByteGet.ByteGet a -> Int -> BitGet a
fromByteGet f n = do
  x <- byteString n
  Get.embed f x

-- | Reads @n@ bits and assembles them into a single value.
--
-- The bits are read one at a time with 'bool'. The first bit read becomes
-- the least significant bit of the result: the right fold starts from the
-- last bit read and shifts the accumulator left once per earlier bit.
--
-- When @n@ is zero or negative no input is consumed and the result is
-- 'Bits.zeroBits'.
bits :: Bits.Bits a => Int -> BitGet a
bits n = do
  let
    -- Make room for one more low bit, then set it when the incoming bit is
    -- 'True'.
    f :: Bits.Bits a => Bool -> a -> a
    f bit x = let y = Bits.shiftL x 1 in if bit then Bits.setBit y 0 else y
  xs <- Monad.replicateM n bool
  pure $ foldr f Bits.zeroBits xs

-- | Reads a single bit from the input.
--
-- Uses 'BitString.pop' to take one bit off the current input. If there is
-- nothing left to pop, fails with 'NotEnoughInput.NotEnoughInput';
-- otherwise the remaining input is stored as the new state and the bit is
-- returned.
bool :: BitGet Bool
bool = do
  s1 <- Get.get
  case BitString.pop s1 of
    Nothing -> throw NotEnoughInput.NotEnoughInput
    Just (x, s2) -> do
      Get.put s2
      pure x

-- | Reads @n@ bytes from the bit stream and packs them into a strict
-- 'ByteString.ByteString'.
--
-- Each byte is obtained by reading 8 individual bits with 'bits', so the
-- read is built entirely from single-bit reads rather than from a
-- byte-level primitive. When @n@ is zero or negative the result is empty.
byteString :: Int -> BitGet ByteString.ByteString
byteString n = fmap ByteString.pack . Monad.replicateM n $ bits 8

-- | Fails the bit-level getter with the given exception.
--
-- This is 'Get.throw' specialised to 'BitGet'.
throw :: Exception.Exception e => e -> BitGet a
throw = Get.throw

-- | Attaches a label to a bit-level getter.
--
-- This is 'Get.label' specialised to 'BitGet'.
--
-- NOTE(review): presumably the label is added to the list of labels
-- reported when the wrapped getter fails -- confirm against
-- "Rattletrap.Get".
label :: String -> BitGet a -> BitGet a
label = Get.label