{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE CPP #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DeriveFunctor #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE MagicHash #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}

module Generic.Random.Internal.Generic where

import Control.Applicative
import Data.Coerce
import Data.Proxy
import GHC.Generics hiding (S, Arity)
import GHC.TypeLits
import Test.QuickCheck

-- * Random generators

-- | Generate a value of a generic type: draw a constructor according to the
-- given weights, then fill in each of its fields with that field's own
-- 'arbitrary'. No size control is applied.
genericArbitrary
  :: forall a
  .  (Generic a, GA Unsized (Rep a))
  => Weights a  -- ^ List of weights for every constructor
  -> Gen a
genericArbitrary (Weights tree total) = unGen' (to <$> rep)
  where
    -- Generator for the generic representation of @a@.
    rep :: Gen' Unsized (Rep a p)
    rep = ga tree total

-- | 'genericArbitrary' with every constructor weighted equally; shorthand
-- for @'genericArbitrary' 'uniform'@.
genericArbitraryU
  :: forall a
  .  (Generic a, GA Unsized (Rep a), UniformWeight (Weights_ (Rep a)))
  => Gen a
genericArbitraryU = genericArbitrary (uniform :: Weights a)

-- | Size-aware variant of 'genericArbitrary', meant for recursive types.
-- Once the generation size reaches 0, a sum type picks among its base cases
-- (when it has any): values of depth at most @n@, where @n@ is a type-level
-- natural built from 'Z' and 'S'. The size is meant to shrink as generation
-- goes deeper, so that this fallback is reached and generation terminates.
genericArbitrary'
  :: forall n a
  . (Generic a, GA (Sized n) (Rep a))
  => n          -- ^ Depth bound for base cases; only its type is used
  -> Weights a  -- ^ List of weights for every constructor
  -> Gen a
genericArbitrary' _ (Weights tree total) = unGen' (to <$> rep)
  where
    -- Generator for the generic representation of @a@, tagged as sized.
    rep :: Gen' (Sized n) (Rep a p)
    rep = ga tree total

-- | Shorthand for @'genericArbitrary'' 'Z' 'uniform'@: uniform constructor
-- weights, with nullary constructors as the base cases.
--
-- NOTE(review): the quantified @n@ is unused by the type; presumably kept
-- so that explicit type applications keep their positions.
genericArbitraryU0
  :: forall n a
  . (Generic a, GA (Sized Z) (Rep a), UniformWeight (Weights_ (Rep a)))
  => Gen a
genericArbitraryU0 = genericArbitrary' Z (uniform :: Weights a)

-- | Shorthand for @'genericArbitrary'' ('S' 'Z') 'uniform'@: uniform
-- constructor weights. Base cases are nullary constructors, and constructors
-- whose fields can all be filled with nullary constructors.
--
-- NOTE(review): the quantified @n@ is unused by the type; presumably kept
-- so that explicit type applications keep their positions.
genericArbitraryU1
  :: forall n a
  . (Generic a, GA (Sized (S Z)) (Rep a), UniformWeight (Weights_ (Rep a)))
  => Gen a
genericArbitraryU1 = genericArbitrary' (S Z) (uniform :: Weights a)

-- * Internal

-- | Shape of the weight tree for a generic representation @f@: each sum
-- becomes a ':|' node, and each constructor becomes a leaf. The datatype
-- wrapper ('M1' 'D') is transparent.
type family Weights_ (f :: * -> *) :: * where
  Weights_ (f :+: g) = Weights_ f :| Weights_ g
  Weights_ (M1 D _c f) = Weights_ f
#if __GLASGOW_HASKELL__ >= 800
  -- Leaves are tagged with the constructor name, so that a 'W' weight can
  -- be checked against the constructor it is given for.
  Weights_ (M1 C ('MetaCons c _i _j) _f) = L c
#else
  -- GHC < 8.0: untagged leaves.
  Weights_ (M1 C _c _f) = ()
#endif

-- | Internal node of a weight tree: the left subtree, the total weight of
-- the left subtree, and the right subtree.
data a :| b = N a Int b
-- | Leaf of a weight tree, standing for the constructor named @c@.
data L (c :: Symbol) = L

-- | Trees of weights assigned to constructors of type @a@,
-- rescaled to obtain a probability distribution.
--
-- There are two ways to construct them:
--
-- @
-- 'weights' (x1 '%' x2 '%' ... '%' xn '%' ()) :: 'Weights' a
-- 'uniform' :: 'Weights' a
-- @
--
-- Using @weights@, there must be exactly as many weights as
-- there are constructors.
--
-- 'uniform' is equivalent to @'weights' (1 '%' ... '%' 1 '%' ())@
-- (automatically fills out the right number of 1s).
--
-- The 'Int' field is the sum of all weights in the tree. Generators draw an
-- index in @[0, total)@ and use the tree to find the constructor it selects.
data Weights a = Weights (Weights_ (Rep a)) Int

-- | Type of a single weight, tagged with the name of the associated
-- constructor for additional compile-time checking.
--
-- @
-- 'weights' ((9 :: 'W' \"Leaf\") '%' (8 :: 'W' \"Node\") '%' ())
-- @
--
-- 'Num' is derived through the newtype, so integer literals can be used
-- directly as weights.
newtype W (c :: Symbol) = W Int deriving Num

-- | Smart constructor for a custom distribution. It takes the triple built
-- by a chain of '%' terminated with @()@: the weight tree, its total weight,
-- and the final @()@.
weights :: (Weights_ (Rep a), Int, ()) -> Weights a
weights (tree, total, ()) = Weights tree total

-- | Uniform distribution: every constructor gets weight 1.
uniform :: UniformWeight (Weights_ (Rep a)) => Weights a
uniform = uncurry Weights uniformWeight

-- | Name of the leftmost constructor of a weight tree, i.e. the constructor
-- that the next weight passed to '%' applies to.
--
-- NOTE(review): there is no equation for @()@ leaves, so with the GHC < 8.0
-- weight trees this family does not reduce.
type family First a :: Symbol where
  First (a :| _b) = First a
  First (L c) = c

-- | Weight trees that can be filled in with a chain of '%', one weight per
-- constructor from left to right.
class WeightBuilder a where
  -- | What must follow the leftmost weight of @a@: the weights for the rest
  -- of @a@, and then @r@.
  type Prec a r

  -- | A binary constructor for building up trees of weights. Returns the
  -- completed tree, its total weight, and the leftover @r@.
  (%) :: W (First a) -> Prec a r -> (a, Int, r)

infixr 1 %

-- | Node: the weights for the left subtree come first, immediately followed
-- (in the leftover of the left builder) by the right subtree and its total.
instance WeightBuilder a => WeightBuilder (a :| b) where
  type Prec (a :| b) r = Prec a (b, Int, r)
  w % rest = (N left leftTotal right, leftTotal + rightTotal, remaining)
    where
      (left, leftTotal, (right, rightTotal, remaining)) = w % rest

-- | Leaf: the single weight is the whole total; what follows is passed
-- through untouched.
instance WeightBuilder (L c) where
  type Prec (L c) r = r
  (%) (W weight) rest = (L, weight, rest)

-- | Untagged leaf (the GHC < 8.0 weight trees): same as the tagged leaf.
instance WeightBuilder () where
  type Prec () r = r
  (%) (W weight) rest = ((), weight, rest)

-- | Weight trees in which every constructor gets weight 1, paired with their
-- total weight (the number of leaves).
class UniformWeight a where
  uniformWeight :: (a, Int)

-- | Node: combine both subtrees, recording the left total for dispatch.
instance (UniformWeight a, UniformWeight b) => UniformWeight (a :| b) where
  uniformWeight = (N left leftTotal right, leftTotal + rightTotal)
    where
      (left, leftTotal) = uniformWeight
      (right, rightTotal) = uniformWeight

-- | Tagged leaf: weight 1.
instance UniformWeight (L c) where
  uniformWeight = (,) L 1

-- | Untagged leaf: weight 1.
instance UniformWeight () where
  uniformWeight = (,) () 1

-- | 'Gen' tagged with a phantom @sized@ parameter ('Sized' or 'Unsized')
-- that selects the 'GA' / 'GASum' instances; its 'Functor', 'Applicative'
-- and 'Monad' instances are those of 'Gen'.
newtype Gen' sized a = Gen' { unGen' :: Gen a }
  deriving (Functor, Applicative, Monad)

-- | Phantom tag for size-aware generation; @n@ is the depth bound (built
-- from 'Z' and 'S') used to look up base cases.
data Sized n
-- | Phantom tag for generation that does not adjust the size.
data Unsized

-- | 'sized' lifted to 'Gen'': build a generator from the current size.
sized' :: (Int -> Gen' sized a) -> Gen' sized a
sized' mk = Gen' (sized (unGen' . mk))

-- | Generic Arbitrary: a generator for the generic representation @f@,
-- given its weight tree and the total weight of that tree. The phantom
-- @sized@ parameter selects 'Sized' or 'Unsized' behaviour.
class GA sized f where
  ga :: Weights_ f -> Int -> Gen' sized (f p)

-- | Datatype wrapper: delegate to the wrapped representation.
instance GA sized f => GA sized (M1 D c f) where
  ga tree total = M1 <$> ga tree total

-- | Single constructor, unsized: generate its fields directly.
instance GAProduct f => GA Unsized (M1 C c f) where
  ga _ _ = Gen' (M1 <$> gaProduct)

-- | Single constructor, sized: the current size is divided by the number of
-- fields ('Arity') before the fields are generated.
--
-- NOTE(review): 'Arity' has no equation for 'U1', so a type whose only
-- constructor is nullary cannot satisfy the @KnownNat@ constraint here.
instance (GAProduct f, KnownNat (Arity f)) => GA (Sized n) (M1 C c f) where
  ga _ _ = Gen' (sized perField)
    where
      perField sz = resize (sz `div` fields) gaProduct
      fields = fromInteger (natVal (Proxy :: Proxy (Arity f)))

-- | Sum type, sized. At size 0, if some constructors have base cases (values
-- of depth at most @n@), draw one of those, weighted by constructor weight.
-- Otherwise draw among all constructors.
instance (GASum (Sized n) f, GASum (Sized n) g, BaseCases n f, BaseCases n g)
  => GA (Sized n) (f :+: g) where
  ga tree total = sized' $ \sz ->
    case unTagged (baseCases tree total :: Tagged n (Weighted ((f :+: g) p))) of
      Weighted (Just (baseGen, baseTotal))
        | sz == 0 -> Gen' (choose (0, baseTotal - 1) >>= baseGen)
      _ -> gaSum' tree total

-- | Sum type, unsized: draw a constructor by weight.
instance (GASum Unsized f, GASum Unsized g) => GA Unsized (f :+: g) where
  ga tree total = gaSum' tree total

-- | Generator for a representation whose weight tree is a single leaf. The
-- leaf is rebuilt here and the total weight is passed as 0.
--
-- NOTE(review): passing 0 is only harmless for 'GA' instances that ignore
-- the total, as the 'M1' 'C' instances do — confirm for other uses.
gArbitrarySingle
  :: forall sized f p c0
  .  (GA sized f, Weights_ f ~ L c0)
  => Gen' sized (f p)
gArbitrarySingle = ga (L :: L c0) 0

-- | Draw an index uniformly in @[0, total)@ and generate the constructor it
-- selects in the weight tree.
gaSum' :: GASum sized f => Weights_ f -> Int -> Gen' sized (f p)
gaSum' tree total = Gen' (choose (0, total - 1)) >>= \i -> gaSum i tree

-- | Generate the constructor that an index in @[0, total)@ selects in the
-- weight tree: ':|' nodes compare the index with the left subtree's total.
class GASum sized f where
  gaSum :: Int -> Weights_ f -> Gen' sized (f p)

-- | Sum node: indices below the left total go left; the rest are shifted
-- down and go right.
instance (GASum sized f, GASum sized g) => GASum sized (f :+: g) where
  gaSum i (N left leftTotal right) =
    if i < leftTotal
      then L1 <$> gaSum i left
      else R1 <$> gaSum (i - leftTotal) right

-- | One constructor inside a sum, unsized: its fields are generated at the
-- current size.
instance GAProduct f => GASum Unsized (M1 i c f) where
  gaSum _ _ = Gen' gaProduct

-- | One constructor inside a sum, sized: the current size is divided among
-- the constructor's fields, as the single-constructor @'GA' ('Sized' n)@
-- instance does. Recursive fields of a sum type thereby shrink toward size
-- 0, where the sized sum instance falls back to base cases. Without this
-- division they would be regenerated at the same size indefinitely.
instance (GAProduct f, KnownNat (SizeDivisor f)) => GASum (Sized n) (M1 i c f) where
  gaSum _ _ = Gen' (sized $ \sz -> resize (sz `div` divisor) gaProduct)
    where
      divisor = fromInteger (natVal (Proxy :: Proxy (SizeDivisor f)))

-- | Divisor applied to the size when generating the fields @f@ of one
-- constructor. It is the 'Arity', except for a nullary constructor ('U1'):
-- 'Arity' has no equation for 'U1', so it counts as 1 here, which also keeps
-- the division defined.
type family SizeDivisor (f :: * -> *) :: Nat where
  SizeDivisor U1 = 1
  SizeDivisor f = Arity f


-- | Generate all fields of a single constructor; each field ('K1') is
-- generated with its own 'arbitrary'.
class GAProduct f where
  gaProduct :: Gen (f p)

-- | No fields.
instance GAProduct U1 where
  gaProduct = return U1

-- | A field: use its own 'Arbitrary' instance.
instance Arbitrary c => GAProduct (K1 i c) where
  gaProduct = K1 <$> arbitrary

-- | Metadata wrapper: transparent.
instance GAProduct f => GAProduct (M1 i c f) where
  gaProduct = M1 <$> gaProduct

-- | Product of fields: generate the left part, then the right part.
instance (GAProduct f, GAProduct g) => GAProduct (f :*: g) where
  gaProduct = (:*:) <$> gaProduct <*> gaProduct

-- | Number of fields in a constructor's representation: each 'M1' (inside a
-- constructor these are the per-field selector wrappers) counts 1, and
-- products add up.
--
-- NOTE(review): there is no equation for 'U1', so @Arity U1@ (a nullary
-- constructor) does not reduce.
type family Arity f :: Nat where
  Arity (f :*: g) = Arity f + Arity g
  Arity (M1 _i _c _f) = 1


-- | A value of type @b@ tagged with a phantom type @a@. Here it carries the
-- depth bound @n@ of base cases.
newtype Tagged a b = Tagged { unTagged :: b }
  deriving Functor

-- $nat
-- Use the 'Z' and 'S' data types to set the depth bound of the base-case
-- values that 'genericArbitrary'' falls back on once the size reaches 0, so
-- that generators for recursive types terminate.

-- | Zero: depth bound under which only nullary constructors are base cases.
data Z = Z

-- | Successor: one more level of depth allowed for base cases.
data S n = S n

-- | A weighted, possibly empty, collection of generators. @Just (g, k)@
-- means that @g@ maps an index in @[0, k)@ to a generator, with total weight
-- @k@; 'Nothing' means there are no alternatives.
newtype Weighted a = Weighted (Maybe (Int -> Gen a, Int))
  deriving Functor

-- | Cartesian product of two weighted collections. With left weight @m@ and
-- right weight @n@, the product has weight @m * n@. An index @i@ in
-- @[0, m * n)@ is split as @(j, k) = i `divMod` n@, which yields a left
-- index @j@ in @[0, m)@ and a right index @k@ in @[0, n)@.
instance Applicative Weighted where
  pure a = Weighted (Just ((pure . pure) a, 1))
  Weighted wf <*> Weighted wa = Weighted $ liftA2 g wf wa
    where
      g (f, m) (a, n) =
        ( \i ->
            -- Divide by the right-hand weight n. Dividing by m would give
            -- the left generator indices up to n - 1 (out of its range when
            -- n > m) and never reach right-hand indices >= m.
            let (j, k) = i `divMod` n
            in f j <*> a k
        , m * n )

-- | Disjoint union of weighted collections: the weights add up. Indices
-- below the left weight select the left side; the rest, shifted down,
-- select the right side.
instance Alternative Weighted where
  empty = Weighted Nothing
  Weighted l <|> Weighted r = Weighted $ case (l, r) of
    -- The right operand is inspected before the left one.
    (_, Nothing) -> l
    (Nothing, _) -> r
    (Just (lGen, m), Just (rGen, n)) ->
      let dispatch i
            | i < m = lGen i
            | otherwise = rGen (i - m)
      in Just (dispatch, m + n)

-- | Base cases (values of depth at most @n@) of a sum of constructors. Each
-- constructor's base cases are weighted by that constructor's weight. The
-- arguments are the weight tree and the total weight of that subtree.
class BaseCases n f where
  baseCases :: Weights_ f -> Int -> Tagged n (Weighted (f p))

-- | Sum node: union of both sides' base cases. The left subtree gets its
-- total @m@ and the right subtree the remainder @n - m@.
instance (BaseCases n f, BaseCases n g) => BaseCases n (f :+: g) where
  baseCases (N l m r) n = merge lefts rights
    where
      lefts = fmap (fmap L1) (baseCases l m)
      rights = fmap (fmap R1) (baseCases r (n - m))
      merge :: Alternative u => Tagged n (u a) -> Tagged n (u a) -> Tagged n (u a)
      merge (Tagged x) (Tagged y) = Tagged (x <|> y)

-- | A single constructor: its base cases, as listed by 'listBaseCases',
-- together carry the constructor's own weight @n@ in the enclosing sum. One
-- of them is then picked uniformly.
instance ListBaseCases n f => BaseCases n (M1 i c f) where
  baseCases _ n = fmap reweigh listBaseCases
    where
      reweigh :: Weighted a -> Weighted a
      reweigh (Weighted h) = Weighted (fmap (\(g, k) -> (uniformly g k, n)) h)
      -- The index handed down by the enclosing sum ranges over [0, n), the
      -- constructor's weight, which is unrelated to the number k of base
      -- cases. Feeding it to g directly would never reach base cases past
      -- the n-th, and would go out of g's range when n > k. So ignore it and
      -- draw an index in [0, k) instead.
      uniformly g k _ = choose (0, k - 1) >>= g

-- | A @ListBaseCases n ('Rep' a)@ constraint basically provides the list of
-- values of type @a@ with depth at most @n@, collected in any 'Alternative'
-- (used here with 'Weighted').
class ListBaseCases n f where
  listBaseCases :: Alternative u => Tagged n (u (f p))

-- | Convenience synonym: @a@ is 'Generic' and its representation has base
-- cases up to depth @n@.
type BaseCases' n a = (Generic a, ListBaseCases n (Rep a))

-- | No fields: exactly one value, at any depth.
instance ListBaseCases n U1 where
  listBaseCases = Tagged { unTagged = pure U1 }

-- | Metadata wrapper: transparent.
instance ListBaseCases n f => ListBaseCases n (M1 i c f) where
  listBaseCases = fmap (fmap M1) listBaseCases

-- | A field at depth 0: no values available.
instance ListBaseCases Z (K1 i c) where
  listBaseCases = Tagged { unTagged = empty }

-- | A field at depth @S n@: the values of the field's type at depth @n@,
-- converted back from their generic representation.
instance (Generic c, ListBaseCases n (Rep c)) => ListBaseCases (S n) (K1 i c) where
  listBaseCases = retag (fmap (fmap (K1 . to)) listBaseCases)
    where
      -- Only the phantom depth tag changes.
      retag :: Tagged n a -> Tagged (S n) a
      retag = coerce

-- | Sum: the values of both alternatives.
instance (ListBaseCases n f, ListBaseCases n g) => ListBaseCases n (f :+: g) where
  listBaseCases = merge (fmap (fmap L1) listBaseCases) (fmap (fmap R1) listBaseCases)
    where
      merge :: Alternative u => Tagged n (u a) -> Tagged n (u a) -> Tagged n (u a)
      merge (Tagged x) (Tagged y) = Tagged (x <|> y)

-- | Product: every combination of values of the two parts.
instance (ListBaseCases n f, ListBaseCases n g) => ListBaseCases n (f :*: g) where
  listBaseCases = zipTagged listBaseCases listBaseCases
    where
      zipTagged
        :: Applicative u
        => Tagged n (u (f p))
        -> Tagged n (u (g p))
        -> Tagged n (u ((f :*: g) p))
      zipTagged (Tagged xs) (Tagged ys) = Tagged ((:*:) <$> xs <*> ys)