-- Copyright 2016 Google Inc. All Rights Reserved.
--
-- Use of this source code is governed by a BSD-style
-- license that can be found in the LICENSE file or at
-- https://developers.google.com/open-source/licenses/bsd

{-# LANGUAGE DeriveFunctor #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE RankNTypes #-}
-- | An Arbitrary instance for protocol buffer Messages to use with QuickCheck.
module Data.ProtoLens.Arbitrary
    ( ArbitraryMessage(..)
    , arbitraryMessage
    ) where

import Data.ProtoLens.Message

import Control.Arrow ((&&&))
import Control.Monad (foldM)
import qualified Data.ByteString as BS
import Data.Map (Map)
import qualified Data.Map as M
import Data.Maybe (isJust, fromJust)
import qualified Data.Text as T
import Lens.Family2 (Lens', view, set)
import Lens.Family2.Unchecked (lens)
import Test.QuickCheck (Arbitrary(..), Gen, suchThat, frequency, listOf,
                        resize, shrinkList, sized)


-- | Wraps a protocol buffer message so that QuickCheck can generate and
-- shrink it through the 'Arbitrary' class.
newtype ArbitraryMessage a = ArbitraryMessage
    { unArbitraryMessage :: a  -- ^ The wrapped message.
    }
    deriving (Eq, Show, Functor)

instance Message a => Arbitrary (ArbitraryMessage a) where
    -- Both methods delegate to the message-generic helpers in this module;
    -- the newtype only exists to select this instance.
    arbitrary = fmap ArbitraryMessage arbitraryMessage
    shrink = map ArbitraryMessage . shrinkMessage . unArbitraryMessage

-- | Generates a random message. It starts from 'def' and fills in every
-- field listed in the message's descriptor, one at a time, in the key order
-- of 'fieldsByTag'.
arbitraryMessage :: Message a => Gen a
arbitraryMessage =
    foldM (\msg field -> arbitraryField field msg) def allFields
  where
    allFields = M.elems (fieldsByTag descriptor)

-- | Lifts a generator to optional values. It picks 'Nothing' with weight 1
-- and @Just@ a generated value with weight 3, mirroring QuickCheck's
-- @Arbitrary (Maybe a)@ instance.
maybeGen :: Gen a -> Gen (Maybe a)
maybeGen gen = frequency weighted
  where
    weighted = [(1, return Nothing), (3, fmap Just gen)]

-- | Generates a map by generating a list of entry messages and converting it
-- with 'mapEntriesLens'. Entries that share a key collapse into one map
-- entry.
mapGen :: (Ord key, Message entry) => Lens' entry key -> Lens' entry value ->
          Gen entry -> Gen (Map key value)
mapGen keyLens valueLens entryGen =
    mapEntriesLens keyLens valueLens regenerate M.empty
  where
    -- The starting (empty) entry list is ignored; a fresh list is generated.
    regenerate _ = listOf entryGen

-- | Overwrites the part of @msg@ focused by the lens with a freshly generated
-- value, ignoring whatever it held before.
setGen :: Lens' msg a -> Gen a -> msg -> Gen msg
setGen fieldLens gen msg = fieldLens (\_ -> gen) msg

-- | Fills one field of a message with a random value. The shape of the value
-- (plain, optional, repeated or map) is chosen from the field's accessor.
arbitraryField :: FieldDescriptor msg -> msg -> Gen msg
arbitraryField (FieldDescriptor _ typeDesc accessor) =
    case accessor of
        PlainField _ fieldLens -> setGen fieldLens valueGen
        OptionalField fieldLens -> setGen fieldLens (maybeGen valueGen)
        RepeatedField _ fieldLens -> setGen fieldLens (listOf valueGen)
        MapField keyLens valueLens mapLens ->
            setGen mapLens (mapGen keyLens valueLens valueGen)
  where
    -- Generator for a single element; for map fields this generates whole
    -- entry messages.
    valueGen = arbitraryFieldValue typeDesc

-- | Generates a single value of the given field type.
--
-- Nested messages and groups are generated at half the current QuickCheck
-- size. 'listOf' sizes repeated and map fields by the current size. Without
-- the halving, a message type that contains a repeated field of its own type
-- (directly or through other messages) would produce about as many children
-- at every nesting level as at the top, and generation would usually never
-- finish. Halving the size shrinks nested repeated fields to empty within a
-- few levels.
--
-- NOTE: recursion through optional fields is not bounded by this, since
-- 'maybeGen' does not depend on the size.
arbitraryFieldValue :: FieldTypeDescriptor value -> Gen value
arbitraryFieldValue ftd = case ftd of
    MessageField -> subMessageGen
    GroupField -> subMessageGen
    -- For enum fields, all we know is that the value is an instance of
    -- MessageEnum, meaning we can only use fromEnum, toEnum, or maybeToEnum. So
    -- we must rely on the instance of Arbitrary for Int and filter out only the
    -- cases that can actually be converted to one of the enum values.
    --
    -- 'fromJust' is okay here because 'suchThat' will ensure that all generated
    -- values are 'Just _'.
    EnumField -> fromJust <$> ((maybeToEnum <$> arbitrary) `suchThat` isJust)
    Int32Field -> arbitrary
    Int64Field -> arbitrary
    UInt32Field -> arbitrary
    UInt64Field -> arbitrary
    SInt32Field -> arbitrary
    SInt64Field -> arbitrary
    Fixed32Field -> arbitrary
    Fixed64Field -> arbitrary
    SFixed32Field -> arbitrary
    SFixed64Field -> arbitrary
    FloatField -> arbitrary
    DoubleField -> arbitrary
    BoolField -> arbitrary
    StringField -> T.pack <$> arbitrary
    BytesField -> BS.pack <$> arbitrary
  where
    -- Generate a nested message at half the current size so that recursive
    -- repeated fields terminate.
    subMessageGen :: Message m => Gen m
    subMessageGen = sized $ \n -> resize (n `div` 2) arbitraryMessage

-- | Shrinks a message one field at a time: every candidate differs from the
-- input in a single field. The candidates for the individual fields are
-- concatenated in the key order of 'fieldsByTag'.
shrinkMessage :: Message a => a -> [a]
shrinkMessage msg =
    concat [shrinkField field msg | field <- M.elems (fieldsByTag descriptor)]

-- | Shrinks an optional value. 'Nothing' has no shrinks. @Just v@ shrinks
-- first to 'Nothing', then to @Just@ of each shrink of @v@.
shrinkMaybe :: (a -> [a]) -> Maybe a -> [Maybe a]
shrinkMaybe shrinkValue = maybe [] (\v -> Nothing : map Just (shrinkValue v))

-- | Shrinks a map by viewing it as a list of entry messages and shrinking that
-- list with 'shrinkList', using the given shrinker for individual entries.
shrinkMap :: (Ord key, Message entry) => Lens' entry key -> Lens' entry value
          -> (entry -> [entry]) -> Map key value -> [Map key value]
shrinkMap keyLens valueLens shrinkEntry m =
    mapEntriesLens keyLens valueLens (shrinkList shrinkEntry) m

-- | Shrink candidates for a message obtained by shrinking exactly one field,
-- leaving the rest of the message untouched.
shrinkField :: FieldDescriptor msg -> msg -> [msg]
shrinkField (FieldDescriptor _ typeDesc accessor) =
    case accessor of
        PlainField _ fieldLens -> fieldLens valueShrinker
        OptionalField fieldLens -> fieldLens (shrinkMaybe valueShrinker)
        RepeatedField _ fieldLens -> fieldLens (shrinkList valueShrinker)
        MapField keyLens valueLens mapLens ->
            mapLens (shrinkMap keyLens valueLens valueShrinker)
  where
    -- The lenses are applied at the list functor, so each shrink candidate
    -- for the field becomes a candidate for the whole message.
    valueShrinker = shrinkFieldValue typeDesc

-- | Shrink candidates for a single value of the given field type.
shrinkFieldValue :: FieldTypeDescriptor value -> value -> [value]
shrinkFieldValue ftd = case ftd of
    -- Groups are messages too, so they shrink exactly like nested messages.
    -- Calling 'shrinkMessage' directly avoids wrapping and unwrapping through
    -- the 'ArbitraryMessage' instance, which does the same thing.
    MessageField -> shrinkMessage
    GroupField -> shrinkMessage
    -- Shrink to the 0-equivalent enum value, but only if 0 is a valid value
    -- of the enum and the value isn't already 0.
    EnumField -> case maybeToEnum 0 of
        Nothing -> const []
        Just zeroVal -> \val -> case fromEnum val of
          0 -> []
          _ -> [zeroVal]
    Int32Field -> shrink
    Int64Field -> shrink
    UInt32Field -> shrink
    UInt64Field -> shrink
    SInt32Field -> shrink
    SInt64Field -> shrink
    Fixed32Field -> shrink
    Fixed64Field -> shrink
    SFixed32Field -> shrink
    SFixed64Field -> shrink
    FloatField -> shrink
    DoubleField -> shrink
    BoolField -> shrink
    -- Text and ByteString values are shrunk through their list representations.
    StringField -> map T.pack . shrink . T.unpack
    BytesField -> map BS.pack . shrink . BS.unpack

-- | Converts each key/value pair of a map into an entry message built from
-- 'def'. Entries come out in ascending key order.
mapToEntries :: Message entry =>
                Lens' entry key -> Lens' entry value -> Map key value -> [entry]
mapToEntries keyLens valueLens m =
    [ set keyLens k (set valueLens v def) | (k, v) <- M.toList m ]

-- | Collects entry messages into a map. If several entries share a key, the
-- one appearing last in the list wins.
entriesToMap :: Ord key =>
                Lens' entry key -> Lens' entry value -> [entry] -> Map key value
entriesToMap keyLens valueLens =
    M.fromList . map (view keyLens &&& view valueLens)

-- | A lens-like view of a 'Map' as a list of entry messages.
--
-- This is not a lawful lens: setting a list of entries and viewing it again
-- need not give back the same list. Entries with duplicate keys are merged
-- inside the 'Map', and the viewed list comes back in key order. It exists
-- only to make converting between entry lists and maps convenient.
--
-- The setter ignores the old map and rebuilds it from the new entry list.
mapEntriesLens :: (Ord key, Message entry) =>
        Lens' entry key -> Lens' entry value -> Lens' (Map key value) [entry]
mapEntriesLens keyLens valueLens = lens toEntries fromEntries
  where
    toEntries = mapToEntries keyLens valueLens
    fromEntries _ = entriesToMap keyLens valueLens