{-# LANGUAGE ViewPatterns #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE UndecidableInstances #-}
module Foundation.Check.Arbitrary
    ( Arbitrary(..)
    , frequency
    , oneof
    , elements
    , between
    ) where

import           Basement.Imports
import           Foundation.Primitive
import           Basement.Nat
import           Basement.IntegralConv
import           Basement.Bounded
import           Basement.Types.OffsetSize
import qualified Basement.Types.Char7 as Char7
import           Basement.Types.Word128 (Word128(..))
import           Basement.Types.Word256 (Word256(..))
import           Foundation.Check.Gen
import           Foundation.Random
import           Foundation.Bits
import           Foundation.Collection
import           Foundation.Numerical
import           Control.Monad (replicateM)

-- | Types that can produce a random value of themselves inside the 'Gen'
-- generator.
class Arbitrary a where
    -- | Generator yielding values of type @a@.
    arbitrary :: Gen a

-- Unbounded numeric types delegate to the dedicated generators defined
-- further down in this module.
instance Arbitrary Natural where
    arbitrary = arbitraryNatural

instance Arbitrary Integer where
    arbitrary = arbitraryInteger

-- Modular integers: draw an unconstrained value and pass it through the
-- 'zn' / 'zn64' constructors (which presumably reduce it modulo @n@ —
-- see Basement.Bounded).
instance KnownNat n => Arbitrary (Zn n) where
    arbitrary = fmap zn arbitraryNatural

instance (NatWithinBound Word64 n, KnownNat n) => Arbitrary (Zn64 n) where
    arbitrary = fmap zn64 arbitrary

-- Primitive machine types.
--
-- Everything is derived from random 'Word64' draws (or their 'Int64'
-- reinterpretation): narrower types keep whatever 'integralDownsize'
-- retains, and the 128/256-bit words are assembled from several separate
-- 64-bit draws.
instance Arbitrary Word64 where
    arbitrary = arbitraryWord64
instance Arbitrary Word32 where
    arbitrary = fmap integralDownsize arbitraryWord64
instance Arbitrary Word16 where
    arbitrary = fmap integralDownsize arbitraryWord64
instance Arbitrary Word8 where
    arbitrary = fmap integralDownsize arbitraryWord64
instance Arbitrary Word where
    arbitrary = fmap word64ToWord arbitraryWord64
instance Arbitrary Word128 where
    arbitrary = fmap Word128 arbitraryWord64 <*> arbitraryWord64
instance Arbitrary Word256 where
    arbitrary = fmap Word256 arbitraryWord64
                    <*> arbitraryWord64
                    <*> arbitraryWord64
                    <*> arbitraryWord64

instance Arbitrary Int64 where
    arbitrary = arbitraryInt64
instance Arbitrary Int32 where
    arbitrary = fmap integralDownsize arbitraryInt64
instance Arbitrary Int16 where
    arbitrary = fmap integralDownsize arbitraryInt64
instance Arbitrary Int8 where
    arbitrary = fmap integralDownsize arbitraryInt64
instance Arbitrary Int where
    arbitrary = fmap int64ToInt arbitraryInt64

instance Arbitrary Char where
    arbitrary = arbitraryChar
instance Arbitrary Char7 where
    arbitrary = fmap (Char7.fromByteMask . integralDownsize) arbitraryWord64
instance Arbitrary (CountOf ty) where
    arbitrary = fmap CountOf arbitrary
    arbitrary = CountOf <$> arbitrary

-- | The lowest bit of a random 'Word64'.
instance Arbitrary Bool where
    arbitrary = fmap (`testBit` 0) arbitraryWord64

-- Text types: pick a length strictly below the 'genMaxSizeString'
-- parameter, then fill it with that many arbitrary characters.
instance Arbitrary String where
    arbitrary = genWithParams $ \params -> fmap fromList $ do
        len <- genMax (genMaxSizeString params)
        replicateM (integralCast len) arbitrary

instance Arbitrary AsciiString where
    arbitrary = genWithParams $ \params -> fmap fromList $ do
        len <- genMax (genMaxSizeString params)
        replicateM (integralCast len) arbitrary

-- Floating point values come straight from the generator's RNG helpers.
instance Arbitrary Double where
    arbitrary = arbitraryF64

instance Arbitrary Float where
    arbitrary = arbitraryF32

-- | 'Nothing' carries weight 1 and 'Just' carries weight 4.
instance Arbitrary a => Arbitrary (Maybe a) where
    arbitrary = frequency (nonEmpty_ choices)
      where
        choices = [ (1, pure Nothing), (4, fmap Just arbitrary) ]

-- | 'Left' and 'Right' are given the same weight via 'oneof'.
instance (Arbitrary l, Arbitrary r) => Arbitrary (Either l r) where
    arbitrary = oneof (nonEmpty_ [fmap Left arbitrary, fmap Right arbitrary])

-- Tuples: every component is generated by its own 'Arbitrary' instance.
instance (Arbitrary a, Arbitrary b)
    => Arbitrary (a,b) where
    arbitrary = fmap (,) arbitrary <*> arbitrary
instance (Arbitrary a, Arbitrary b, Arbitrary c)
    => Arbitrary (a,b,c) where
    arbitrary = fmap (,,) arbitrary <*> arbitrary <*> arbitrary
instance (Arbitrary a, Arbitrary b, Arbitrary c, Arbitrary d)
    => Arbitrary (a,b,c,d) where
    arbitrary = fmap (,,,) arbitrary <*> arbitrary <*> arbitrary <*> arbitrary
instance (Arbitrary a, Arbitrary b, Arbitrary c, Arbitrary d, Arbitrary e)
    => Arbitrary (a,b,c,d,e) where
    arbitrary = fmap (,,,,) arbitrary
                    <*> arbitrary <*> arbitrary <*> arbitrary <*> arbitrary
instance (Arbitrary a, Arbitrary b, Arbitrary c, Arbitrary d, Arbitrary e, Arbitrary f)
    => Arbitrary (a,b,c,d,e,f) where
    arbitrary = fmap (,,,,,) arbitrary
                    <*> arbitrary <*> arbitrary <*> arbitrary <*> arbitrary <*> arbitrary

-- | Lists get a random length strictly below the 'genMaxSizeArray'
-- parameter, each element drawn from the element type's instance.
instance Arbitrary a => Arbitrary [a] where
    arbitrary = genWithParams $ \params -> fmap fromList $ do
        len <- genMax (genMaxSizeArray params)
        replicateM (integralCast len) arbitrary

-- | Generate an arbitrary 'Integer'.
--
-- The magnitude is read as a big-endian number from a short random byte
-- array whose length is bounded by 2, 4, 8 or 16 (the shorter bounds are
-- given larger weights). Each bound is used both with and without negation.
arbitraryInteger :: Gen Integer
arbitraryInteger =
    -- TODO use the sized parameter
    frequency $ nonEmpty_
        [ (4, integerOfSize True 2)
        , (4, integerOfSize False 2)
        , (4, integerOfSize True 4)
        , (4, integerOfSize False 4)
        , (2, integerOfSize True 8)
        , (2, integerOfSize False 8)
        , (1, integerOfSize True 16)
        , (1, integerOfSize False 16)
        ]
  where
    -- @integerOfSize toSign n@ combines a random array of at most @n@ bytes
    -- (see 'arbitraryUArrayOf') into one 'Integer', negated when @toSign@.
    integerOfSize :: Bool -> Word -> Gen Integer
    integerOfSize toSign n =
        -- Each byte shifts the accumulator left by 8 bits, so a k-byte array
        -- can reach magnitudes up to 256^k; merely summing the bytes would
        -- cap every result at 255 * k and never produce large integers.
        (applySign . foldl' (\acc b -> acc * 256 + integralUpsize b) 0 . toList)
            <$> (arbitraryUArrayOf n :: Gen (UArray Word8))
      where
        applySign = if toSign then negate else id

-- | Non-negative integers, obtained from the absolute value of an arbitrary
-- 'Integer'.
arbitraryNatural :: Gen Natural
arbitraryNatural = fmap (\i -> integralDownsize (abs i)) arbitraryInteger

-- | Generate a 'Char', biased towards ASCII.
--
-- With weight 6 the code point is below 128; with weight 1 it is anywhere
-- in @[0, 0x10FFFF]@. 'genMax' excludes its bound (like the ASCII case's
-- @genMax 128@), so the wide range uses @0x110000@ in order to include the
-- last code point U+10FFFF itself.
--
-- NOTE(review): the wide range also covers the surrogate block
-- U+D800..U+DFFF, which cannot appear in well-formed UTF-8 text — confirm
-- that consumers of this generator accept such characters.
arbitraryChar :: Gen Char
arbitraryChar = frequency $ nonEmpty_
    [ (6, wordToChar <$> genMax 128)
    , (1, wordToChar <$> genMax 0x110000)
    ]

-- | A random 'Word64' taken directly from the generator's RNG.
arbitraryWord64 :: Gen Word64
arbitraryWord64 = genWithRng getRandomWord64

-- | A random 'Word64' reinterpreted as 'Int64' through 'integralCast'.
arbitraryInt64 :: Gen Int64
arbitraryInt64 = fmap integralCast arbitraryWord64

-- | A random 'Double' taken directly from the generator's RNG.
arbitraryF64 :: Gen Double
arbitraryF64 = genWithRng getRandomF64

-- | A random 'Float' taken directly from the generator's RNG.
arbitraryF32 :: Gen Float
arbitraryF32 = genWithRng getRandomF32

-- | An unboxed array of arbitrary elements whose length is drawn with
-- @'between' (0, size)@.
arbitraryUArrayOf :: (PrimType ty, Arbitrary ty) => Word -> Gen (UArray ty)
arbitraryUArrayOf size = do
    len <- between (0, size)
    fromList <$> replicateM (integralCast len) arbitrary

-- | Run one of the given generators, chosen at random according to its
-- weight.
--
-- An entry of weight @k@ is chosen with probability @k / total@, where
-- @total@ is the sum of all weights; entries of weight 0 are never chosen.
-- Fails with an error when every weight is 0.
frequency :: NonEmpty [(Word, Gen a)] -> Gen a
frequency (getNonEmpty -> l)
    | total == 0 = error "frequency: total weight is zero"
    | otherwise  = between (0, total) >>= pickOne l
  where
    total :: Word
    !total = foldl' (+) 0 $ fmap fst l

    -- 'between (0, total)' draws from the half-open range [0, total), so an
    -- entry of weight k must own exactly k consecutive values: the test has
    -- to be a strict '<'. With '<=' the first entry got k+1 values and the
    -- last one k-1 (a two-way 'oneof' never picked its second generator).
    pickOne ((k,x):xs) n
        | n < k     = x
        | otherwise = pickOne xs (n-k)
    pickOne _ _ = error "frequency"

-- | Run one of the given generators, each one given weight 1.
oneof :: NonEmpty [Gen a] -> Gen a
oneof = frequency . nonEmptyFmap ((,) 1)

-- | Return one of the given values, each one given weight 1.
elements :: NonEmpty [a] -> Gen a
elements = frequency . nonEmptyFmap weighted
  where
    weighted v = (1, pure v)

-- | Draw a 'Word' from the half-open range @[lo, hi)@: the upper bound is
-- never returned.
--
-- NOTE(review): callers should ensure @lo < hi@. With @lo == hi@ the
-- underlying 'genMax' is asked to reduce modulo 0, and with @lo > hi@ the
-- 'Word' subtraction presumably wraps around.
between :: (Word, Word) -> Gen Word
between (lo, hi) = fmap (lo +) (genMax (hi - lo))

-- | Draw a 'Word' in @[0, m)@ by reducing an arbitrary 'Word' modulo @m@.
--
-- NOTE(review): plain modulo reduction slightly favours small results when
-- @m@ does not evenly divide the 'Word' range, and @m == 0@ leaves the
-- modulo undefined, so callers must pass a positive bound.
genMax :: Word -> Gen Word
genMax m = fmap (`mod` m) arbitrary