{-# OPTIONS_GHC -fno-warn-orphans #-}
module Data.Strict.Containers.Serialise
  (
  ) where

import           Codec.CBOR.Decoding
import           Codec.CBOR.Encoding
import           Codec.Serialise.Class
import           Data.Hashable (Hashable)
import           Data.Semigroup (Semigroup (..)) -- helps with compatibility

import qualified Data.Foldable as Foldable
import qualified Data.Strict.HashMap as HashMap
import qualified Data.Strict.IntMap as IntMap
import qualified Data.Strict.Map as Map
import qualified Data.Strict.Sequence as Sequence
import qualified Data.Strict.Vector as Vector


-- code copied from serialise

-- | Decode a length-prefixed container while guarding against hostile
-- length prefixes.
--
-- The element count returned by the length decoder is compared with the
-- amount of input currently available ('peekAvailable'):
--
-- * If the count fits, the whole container is built with one call to the
--   replicate function.
-- * Otherwise it is built piecewise. There is one remainder chunk, followed
--   by chunks of @max limit 128@ elements each. The chunks are then joined
--   with the concat function.
--
-- This way a forged, huge length cannot trigger one huge up-front
-- allocation.
decodeContainerSkelWithReplicate
  :: (Serialise a)
  => Decoder s Int
     -- ^ How to get the size of the container
  -> (Int -> Decoder s a -> Decoder s container)
     -- ^ replicateM for the container
  -> ([container] -> container)
     -- ^ concat for the container
  -> Decoder s container
decodeContainerSkelWithReplicate decodeLen replicateFun fromList = do
    -- Look at how much data we have at the moment and use it as the limit for
    -- the size of a single call to replicateFun. We don't want to use
    -- replicateFun directly on the result of decodeLen since this might lead to
    -- DOS attack (attacker providing a huge value for length). So if it's above
    -- our limit, we'll do manual chunking and then combine the containers into
    -- one.
    size <- decodeLen
    limit <- peekAvailable
    if size <= limit
       then replicateFun size decode
       else do
           -- Take the max of limit and a fixed chunk size (note: limit can be
           -- 0). This basically means that the attacker can make us allocate a
           -- container of size 128 even though there's no actual input.
           let chunkSize = max limit 128
               (d, m) = size `divMod` chunkSize
               buildOne s = replicateFun s decode
           -- The remainder chunk (m elements, possibly 0) is decoded first,
           -- followed by d full chunks of chunkSize elements.
           containers <- sequence $ buildOne m : replicate d (buildOne chunkSize)
           return $! fromList containers
{-# INLINE decodeContainerSkelWithReplicate #-}

-- | Strict 'Map.Map's are encoded as a CBOR map. The size and key/value fold
-- are passed to 'encodeMapSkel'. Decoding reads the key/value pairs with
-- 'decodeMapSkel' and rebuilds the map with 'Map.fromList'.
instance (Ord k, Serialise k, Serialise v) => Serialise (Map.Map k v) where
  encode = encodeMapSkel Map.size Map.foldrWithKey
  decode = decodeMapSkel Map.fromList

-- | Strict 'HashMap.HashMap's are encoded as a CBOR map. The size and
-- key/value fold are passed to 'encodeMapSkel'. Decoding reads the key/value
-- pairs with 'decodeMapSkel' and rebuilds the map with 'HashMap.fromList'.
instance (Serialise k, Hashable k, Eq k, Serialise v) =>
  Serialise (HashMap.HashMap k v) where
  encode = encodeMapSkel HashMap.size HashMap.foldrWithKey
  decode = decodeMapSkel HashMap.fromList

-- | Strict 'IntMap.IntMap's are encoded as a CBOR map with 'Int' keys. The
-- size and key/value fold are passed to 'encodeMapSkel'. Decoding reads the
-- pairs with 'decodeMapSkel' and rebuilds the map with 'IntMap.fromList'.
instance (Serialise a) => Serialise (IntMap.IntMap a) where
  encode = encodeMapSkel IntMap.size IntMap.foldrWithKey
  decode = decodeMapSkel IntMap.fromList

-- | Strict 'Sequence.Seq's are encoded as a CBOR list: a list-length header
-- followed by each element in order.
--
-- Decoding goes through 'decodeContainerSkelWithReplicate'. When the
-- declared length exceeds the currently available input, it is split into
-- chunks with 'Sequence.replicateM'. The chunks are joined with 'mconcat'.
instance (Serialise a) => Serialise (Sequence.Seq a) where
  encode = encodeContainerSkel
             encodeListLen
             Sequence.length
             Foldable.foldr
             (\a b -> encode a <> b)
  decode = decodeContainerSkelWithReplicate
             decodeListLen
             Sequence.replicateM
             mconcat

-- | Strict 'Vector.Vector's reuse serialise's generic-vector helpers
-- 'encodeVector' and 'decodeVector'.
instance (Serialise a) => Serialise (Vector.Vector a) where
  encode = encodeVector
  {-# INLINE encode #-}
  decode = decodeVector
  {-# INLINE decode #-}