{-# OPTIONS_GHC -fno-warn-orphans #-}
module Data.Strict.Containers.Serialise
  (
  ) where

import           Codec.CBOR.Decoding
import           Codec.CBOR.Encoding
import           Codec.Serialise.Class
import           Data.Hashable (Hashable)
import           Data.Semigroup (Semigroup (..)) -- helps with compatibility

import qualified Data.Foldable as Foldable
import qualified Data.Strict.HashMap as HashMap
import qualified Data.Strict.IntMap as IntMap
import qualified Data.Strict.Map as Map
import qualified Data.Strict.Sequence as Sequence
import qualified Data.Strict.Vector as Vector


-- code copied from serialise

-- | Decode a length-prefixed container whose elements are decoded with
-- 'decode', building it with a @replicateM@-style function.
--
-- The declared length is not trusted: when it exceeds the amount of input
-- currently available, the container is built in bounded chunks which are
-- then combined with the supplied concatenation function.
decodeContainerSkelWithReplicate
  :: (Serialise a)
  => Decoder s Int
     -- ^ How to get the size of the container
  -> (Int -> Decoder s a -> Decoder s container)
     -- ^ replicateM for the container
  -> ([container] -> container)
     -- ^ concat for the container
  -> Decoder s container
decodeContainerSkelWithReplicate decodeLen replicateFun fromList = do
    -- Look at how much data we have at the moment and use it as the limit for
    -- the size of a single call to replicateFun. We don't want to use
    -- replicateFun directly on the result of decodeLen since this might lead to
    -- DOS attack (attacker providing a huge value for length). So if it's above
    -- our limit, we'll do manual chunking and then combine the containers into
    -- one.
    size <- decodeLen
    limit <- peekAvailable
    if size <= limit
       then replicateFun size decode
       else do
           -- Take the max of limit and a fixed chunk size (note: limit can be
           -- 0). This basically means that the attacker can make us allocate a
           -- container of size 128 even though there's no actual input.
           let chunkSize = max limit 128
               (d, m) = size `divMod` chunkSize
               buildOne s = replicateFun s decode
           -- Decode the remainder chunk (m elements) first, followed by d
           -- full-size chunks, in that order.
           containers <- mapM buildOne (m : replicate d chunkSize)
           -- Force the combined container before returning it.
           return $! fromList containers
{-# INLINE decodeContainerSkelWithReplicate #-}

-- | Orphan 'Serialise' instance for the strict 'Map.Map'.
--
-- Encoding delegates to 'encodeMapSkel', supplying 'Map.size' and
-- 'Map.foldrWithKey'; decoding delegates to 'decodeMapSkel', rebuilding the
-- map from the decoded key/value pairs with 'Map.fromList'.
instance (Ord k, Serialise k, Serialise v) => Serialise (Map.Map k v) where
  encode = encodeMapSkel Map.size Map.foldrWithKey
  decode = decodeMapSkel Map.fromList

-- | Orphan 'Serialise' instance for the strict 'HashMap.HashMap'.
--
-- Encoding delegates to 'encodeMapSkel', supplying 'HashMap.size' and
-- 'HashMap.foldrWithKey'; decoding delegates to 'decodeMapSkel', rebuilding
-- the map from the decoded key/value pairs with 'HashMap.fromList'.
instance (Serialise k, Hashable k, Eq k, Serialise v) =>
  Serialise (HashMap.HashMap k v) where
  encode = encodeMapSkel HashMap.size HashMap.foldrWithKey
  decode = decodeMapSkel HashMap.fromList

-- | Orphan 'Serialise' instance for the strict 'IntMap.IntMap'.
--
-- Encoding delegates to 'encodeMapSkel', supplying 'IntMap.size' and
-- 'IntMap.foldrWithKey'; decoding delegates to 'decodeMapSkel', rebuilding
-- the map from the decoded @(Int, a)@ pairs with 'IntMap.fromList'.
instance (Serialise a) => Serialise (IntMap.IntMap a) where
  encode = encodeMapSkel IntMap.size IntMap.foldrWithKey
  decode = decodeMapSkel IntMap.fromList

-- | Orphan 'Serialise' instance for the strict 'Sequence.Seq'.
--
-- Encoding goes through 'encodeContainerSkel': the length header is produced
-- by 'encodeListLen' from 'Sequence.length', and the elements are folded
-- right-to-left so that each element's encoding is prepended to the rest.
--
-- Decoding reads the length with 'decodeListLen' and builds the sequence via
-- 'decodeContainerSkelWithReplicate', using 'Sequence.replicateM' per chunk
-- and 'mconcat' to join the chunks.
instance (Serialise a) => Serialise (Sequence.Seq a) where
  encode = encodeContainerSkel
             encodeListLen
             Sequence.length
             Foldable.foldr
             (\a b -> encode a <> b)
  decode = decodeContainerSkelWithReplicate
             decodeListLen
             Sequence.replicateM
             mconcat

-- | Orphan 'Serialise' instance for the strict 'Vector.Vector'.
--
-- Both directions delegate to the generic-vector helpers 'encodeVector' and
-- 'decodeVector' from "Codec.Serialise.Class".
instance (Serialise a) => Serialise (Vector.Vector a) where
  encode = encodeVector
  {-# INLINE encode #-}
  decode = decodeVector
  {-# INLINE decode #-}