module Generic.Random.Internal.Generic where
import Control.Applicative
import Data.Coerce
import Data.Proxy
import GHC.Generics hiding (S, Arity)
import GHC.TypeLits
import Test.QuickCheck
-- | Pick a constructor according to the given 'Weights', then fill its
-- fields with 'arbitrary'.
--
-- The generic layer does not resize or inspect the QuickCheck size here
-- (the 'Unsized' instances are used).
genericArbitrary
  :: forall a
  . (Generic a, GA Unsized (Rep a))
  => Weights a
  -> Gen a
genericArbitrary (Weights w n) = to <$> unGen' repGen
  where
    -- Generator for the generic representation of @a@.
    repGen :: Gen' Unsized (Rep a p)
    repGen = ga w n
-- | 'genericArbitrary' with every constructor given weight 1 (see 'uniform').
genericArbitraryU
  :: forall a
  . (Generic a, GA Unsized (Rep a), UniformWeight (Weights_ (Rep a)))
  => Gen a
genericArbitraryU = genericArbitrary (uniform :: Weights a)
-- | Size-aware variant of 'genericArbitrary'.
--
-- The first argument is only used for its type @n@, which is the depth
-- searched for base cases (see 'Z' and 'S').
genericArbitrary'
  :: forall n a
  . (Generic a, GA (Sized n) (Rep a))
  => n
  -> Weights a
  -> Gen a
genericArbitrary' _ (Weights w total) = to <$> unGen' repGen
  where
    -- Generator for the generic representation, in size-aware mode.
    repGen :: Gen' (Sized n) (Rep a p)
    repGen = ga w total
-- | 'genericArbitrary'' with uniform weights and base-case depth 'Z'.
--
-- At depth 'Z' a field never contributes a base case, so only constructors
-- without fields can serve as base cases at size 0.
--
-- NOTE(review): the type variable @n@ is unused. It is kept so the
-- variable order seen by type applications does not change.
genericArbitraryU0
  :: forall n a
  . (Generic a, GA (Sized Z) (Rep a), UniformWeight (Weights_ (Rep a)))
  => Gen a
genericArbitraryU0 = genericArbitrary' Z (uniform :: Weights a)
-- | 'genericArbitrary'' with uniform weights and base-case depth @'S' 'Z'@.
--
-- At this depth, a field may contribute a base case when the field's type
-- has base cases at depth 'Z'.
--
-- NOTE(review): the type variable @n@ is unused. It is kept so the
-- variable order seen by type applications does not change.
genericArbitraryU1
  :: forall n a
  . (Generic a, GA (Sized (S Z)) (Rep a), UniformWeight (Weights_ (Rep a)))
  => Gen a
genericArbitraryU1 = genericArbitrary' (S Z) (uniform :: Weights a)
-- | Shape of the weights for a generic representation.
--
-- It is a binary tree mirroring the ':+:' nesting of 'Rep', with one leaf
-- per constructor; datatype metadata ('M1' 'D') is transparent.
-- On GHC >= 8.0 each leaf carries the constructor's name (@L c@). On older
-- compilers the name is not available, so leaves are @()@.
type family Weights_ (f :: * -> *) :: * where
  Weights_ (f :+: g) = Weights_ f :| Weights_ g
  Weights_ (M1 D _c f) = Weights_ f
#if __GLASGOW_HASKELL__ >= 800
  Weights_ (M1 C ('MetaCons c _i _j) _f) = L c
#else
  Weights_ (M1 C _c _f) = ()
#endif
-- | Inner node of a weight tree. Fields: the left subtree, the total weight
-- of the left subtree, and the right subtree.
data a :| b = N a Int b

-- | Leaf of a weight tree, tagged with a constructor name at the type level.
data L (c :: Symbol) = L

-- | Weight tree for the constructors of @a@, together with its total weight.
data Weights a = Weights (Weights_ (Rep a)) Int

-- | Weight for the constructor named @c@. The derived 'Num' instance lets
-- integer literals be used directly as weights.
newtype W (c :: Symbol) = W Int deriving Num
-- | Finish a weight specification built with '%'.
--
-- The chain must end in @()@. The accumulated total becomes the total
-- stored in 'Weights'.
weights :: (Weights_ (Rep a), Int, ()) -> Weights a
weights (tree, total, ()) = Weights tree total
-- | Weights assigning 1 to every constructor; the total is the number of
-- constructors.
uniform :: UniformWeight (Weights_ (Rep a)) => Weights a
uniform = uncurry Weights uniformWeight
-- | Name of the leftmost constructor in a weight tree.
--
-- This is the constructor whose weight the next '%' expects, so a weight
-- annotated with the wrong constructor name is rejected by the type checker.
type family First a :: Symbol where
  First (a :| _b) = First a
  First (L c) = c
-- | Builds a weight tree of type @a@ from a right-nested chain of weights.
--
-- Chains are written with '%' and terminated by @()@, e.g.
-- @weights (x % y % z % ())@ for a three-constructor type.
class WeightBuilder a where
  -- | Type of what follows the weights for @a@ in the chain; @r@ is the
  -- leftover after the whole of @a@ has been consumed.
  type Prec a r
  -- | Take the weight of the first constructor of @a@ and the rest of the
  -- chain. Return the built tree, its total weight, and the leftover @r@.
  (%) :: W (First a) -> Prec a r -> (a, Int, r)
  infixr 1 %
-- | A node consumes the weights of its left subtree first. The left
-- subtree's leftover then contains the right subtree, its total, and
-- whatever follows.
instance WeightBuilder a => WeightBuilder (a :| b) where
  type Prec (a :| b) r = Prec a (b, Int, r)
  m % prec = (N left leftTotal right, leftTotal + rightTotal, rest)
    where
      (left, leftTotal, (right, rightTotal, rest)) = m % prec
-- | A leaf consumes exactly one weight and passes the rest through.
instance WeightBuilder (L c) where
  type Prec (L c) r = r
  (%) (W weight) rest = (L, weight, rest)
-- | Unnamed leaf (the leaf shape used before GHC 8.0): consumes exactly one
-- weight.
instance WeightBuilder () where
  type Prec () r = r
  (%) (W weight) rest = ((), weight, rest)
-- | A weight tree where every leaf has weight 1, paired with its total
-- weight.
class UniformWeight a where
  uniformWeight :: (a, Int)
-- | A node's total is the sum of its subtrees' totals; the left total is
-- also stored in the node.
instance (UniformWeight a, UniformWeight b) => UniformWeight (a :| b) where
  uniformWeight = (N left leftTotal right, leftTotal + rightTotal)
    where
      (left, leftTotal) = uniformWeight
      (right, rightTotal) = uniformWeight
-- | A named leaf weighs 1.
instance UniformWeight (L c) where
  uniformWeight = (,) L 1
-- | An unnamed leaf weighs 1.
instance UniformWeight () where
  uniformWeight = (,) () 1
-- | 'Gen' with a phantom @sized@ index ('Sized' or 'Unsized'). The index
-- selects which 'GA' instances are used; the instances are derived from
-- 'Gen'.
newtype Gen' sized a = Gen' { unGen' :: Gen a }
  deriving (Functor, Applicative, Monad)
-- | Index for 'Gen'': size-aware generation, searching for base cases up to
-- depth @n@ (see 'Z' and 'S').
data Sized n
-- | Index for 'Gen'': the generic layer neither resizes nor inspects the
-- size.
data Unsized
-- | QuickCheck's 'sized', lifted to 'Gen''.
sized' :: (Int -> Gen' sized a) -> Gen' sized a
sized' f = Gen' (sized (unGen' . f))
-- | Generic generator for a representation @f@, given its weight tree and
-- the tree's total weight.
class GA sized f where
  ga :: Weights_ f -> Int -> Gen' sized (f p)
-- | Datatype metadata is transparent: weights and total pass straight
-- through.
instance GA sized f => GA sized (M1 D c f) where
  ga w n = M1 <$> ga w n
-- | A single constructor has no choice to make, so just generate its fields.
instance GAProduct f => GA Unsized (M1 C c f) where
  ga _ _ = Gen' (M1 <$> gaProduct)
-- | A single constructor under size control: the current size is divided
-- evenly among its fields.
instance (GAProduct f, KnownNat (Arity f)) => GA (Sized n) (M1 C c f) where
  ga _ _ = Gen' (sized perField)
    where
      -- Give each field an equal share of the current size.
      perField total = resize (total `div` arity) gaProduct
      -- Number of fields, read from the type-level 'Arity'.
      arity = fromInteger (natVal (Proxy :: Proxy (Arity f)))
-- | A sum of constructors under size control.
--
-- At size 0, if the base-case search at depth @n@ found any base cases, one
-- of them is picked. Each constructor's base cases are weighted by that
-- constructor's weight. Otherwise, and at every positive size, any
-- constructor may be picked via 'gaSum''.
instance (GASum (Sized n) f, GASum (Sized n) g, BaseCases n f, BaseCases n g)
  => GA (Sized n) (f :+: g) where
  ga w n = sized' $ \sz ->
    case unTagged (baseCases w n :: Tagged n (Weighted ((f :+: g) p))) of
      Weighted (Just (pickBase, baseTotal)) | sz == 0 ->
        -- Indices are 0-based, so the last valid one is baseTotal - 1.
        Gen' (choose (0, baseTotal - 1) >>= pickBase)
      _ -> gaSum' w n
-- | A sum of constructors without size control: always pick by weight via
-- 'gaSum''.
instance (GASum Unsized f, GASum Unsized g) => GA Unsized (f :+: g) where
  ga w total = gaSum' w total
-- | Generator for a representation whose weight tree is a single leaf
-- (one constructor).
--
-- The leaf is passed directly along with a total of 0. The single-constructor
-- 'GA' instances ignore both arguments.
gArbitrarySingle
  :: forall sized f p c0
  . (GA sized f, Weights_ f ~ L c0)
  => Gen' sized (f p)
gArbitrarySingle = ga (L :: L c0) 0
-- | Draw an index uniformly from @[0, n)@, where @n@ is the total weight,
-- and let 'gaSum' walk the weight tree to the constructor that index
-- selects.
--
-- NOTE(review): assumes n >= 1 (at least one non-zero weight).
gaSum' :: GASum sized f => Weights_ f -> Int -> Gen' sized (f p)
gaSum' w n = do
  -- Indices are 0-based: the valid ones are 0 .. n - 1.
  i <- Gen' $ choose (0, n - 1)
  gaSum i w
-- | Generate the constructor selected by an index into a weight tree.
class GASum sized f where
  gaSum :: Int -> Weights_ f -> Gen' sized (f p)
-- | Descend one node of the weight tree.
--
-- An index below the left subtree's total weight @n@ selects the left
-- branch. Otherwise @n@ is subtracted so the index is relative to the right
-- subtree.
instance (GASum sized f, GASum sized g) => GASum sized (f :+: g) where
  gaSum i (N a n b)
    | i < n = fmap L1 (gaSum i a)
    | otherwise = fmap R1 (gaSum (i - n) b)
-- | A leaf of the sum: the index has already selected this constructor, so
-- just generate its fields.
instance GAProduct f => GASum sized (M1 i c f) where
  gaSum = \_ _ -> Gen' gaProduct
-- | Generator for the fields of one constructor; each field ('K1') uses its
-- 'Arbitrary' instance.
class GAProduct f where
  gaProduct :: Gen (f p)
-- | No fields: always the unit representation.
instance GAProduct U1 where
  gaProduct = return U1
-- | A field is generated by its own 'Arbitrary' instance.
instance Arbitrary c => GAProduct (K1 i c) where
  gaProduct = K1 <$> arbitrary
-- | Metadata wrappers are transparent.
instance GAProduct f => GAProduct (M1 i c f) where
  gaProduct = M1 <$> gaProduct
-- | Both halves of a product are generated, left then right.
instance (GAProduct f, GAProduct g) => GAProduct (f :*: g) where
  gaProduct = (:*:) <$> gaProduct <*> gaProduct
-- | Number of fields of a constructor's representation. The size-aware
-- single-constructor 'GA' instance divides the size by this number.
--
-- A nullary constructor ('U1') counts as 1 rather than 0. Without an
-- equation for 'U1' the family is stuck and @KnownNat (Arity U1)@ cannot be
-- solved, so a single nullary constructor could not use the size-aware
-- generator. Using 1 also keeps that division by 'Arity' away from zero, and
-- 'U1' ignores the size anyway.
type family Arity f :: Nat where
  Arity (f :*: g) = Arity f + Arity g
  Arity (M1 _i _c _f) = 1
  Arity U1 = 1
-- | A value of type @b@ tagged with a phantom type @a@. Here it carries the
-- base-case search depth at the type level.
newtype Tagged a b = Tagged { unTagged :: b }
  deriving Functor
-- | Naturals for the base-case search depth: 'Z' is zero and @'S' n@ is one
-- more than @n@. Each level lets a field's own base cases be used.
data Z = Z
data S n = S n
-- | An optional weighted choice.
--
-- @Nothing@ means there is nothing to choose from. @Just (g, w)@ means
-- @g i@ generates the option selected by an index @i@ in @[0, w)@.
newtype Weighted a = Weighted (Maybe (Int -> Gen a, Int))
  deriving Functor
-- | Product of weighted choices.
--
-- The combined index space is the Cartesian product of both index spaces,
-- with total weight @m * n@. 'pure' is a single option of weight 1 that
-- ignores its index.
instance Applicative Weighted where
  pure a = Weighted (Just ((pure . pure) a, 1))
  Weighted f <*> Weighted a = Weighted $ liftA2 g f a
    where
      g (pickF, m) (pickA, n) =
        ( \i ->
            -- Split i in [0, m*n) into j in [0, m) for the left choice and
            -- k in [0, n) for the right one. Dividing by n (not m) keeps
            -- both components within their own ranges.
            let (j, k) = i `divMod` n
            in pickF j <*> pickA k
        , m * n )
-- | Disjoint union of weighted choices.
--
-- Indices in @[0, m)@ go to the left choice. Indices in @[m, m + n)@ are
-- shifted down by @m@ and go to the right choice. 'empty' (no options) is
-- the identity.
instance Alternative Weighted where
  empty = Weighted Nothing
  a <|> Weighted Nothing = a
  Weighted Nothing <|> b = b
  Weighted (Just (a, m)) <|> Weighted (Just (b, n)) = Weighted . Just $
    ( \i ->
        if i < m then
          a i
        else
          b (i - m)
    , m + n )
-- | Base cases of a sum of constructors, given the sum's weight tree and
-- total weight; each constructor's base cases carry that constructor's
-- weight.
class BaseCases n f where
  baseCases :: Weights_ f -> Int -> Tagged n (Weighted (f p))
-- | Split the weight tree at this node and join both sides' base cases.
--
-- The left subtree's total is stored in the node (@m@); the right
-- subtree's total is what remains of @n@.
instance (BaseCases n f, BaseCases n g) => BaseCases n (f :+: g) where
  baseCases (N a m b) n =
    orElse
      ((fmap . fmap) L1 (baseCases a m))
      ((fmap . fmap) R1 (baseCases b (n - m)))
    where
      -- '<|>' under the depth tag. Named to avoid shadowing Prelude's concat.
      orElse :: Alternative u => Tagged n (u a) -> Tagged n (u a) -> Tagged n (u a)
      orElse (Tagged x) (Tagged y) = Tagged (x <|> y)
-- | Base cases of a single constructor, re-weighted with that constructor's
-- weight @n@. This makes base cases of different constructors compete in
-- proportion to their weights.
instance ListBaseCases n f => BaseCases n (M1 i c f) where
  baseCases _ n = fmap reweigh listBaseCases
    where
      reweigh :: Weighted a -> Weighted a
      reweigh (Weighted h) = Weighted (fmap (\(g, w) -> (pickInner g w, n)) h)
      -- The outer index ranges over [0, n) (this constructor's weight), which
      -- is unrelated to the inner weight w of the constructor's own base
      -- cases. Feeding it to g directly would skip some inner shapes or
      -- over-weight the last one. Instead, draw a fresh inner index uniformly
      -- in [0, w); w >= 1 for any non-empty 'Weighted'.
      pickInner g w _ = choose (0, w - 1) >>= g
-- | All values of @f@ that can be built while descending at most @n@ levels
-- into fields. They are collected in any 'Alternative'; this module uses
-- 'Weighted'.
class ListBaseCases n f where
  listBaseCases :: Alternative u => Tagged n (u (f p))
-- | Constraint synonym: @a@ is 'Generic' and its representation can be
-- searched for base cases at depth @n@.
type BaseCases' n a = (Generic a, ListBaseCases n (Rep a))
-- | A constructor without fields is always a base case.
instance ListBaseCases n U1 where
  listBaseCases = Tagged $ pure U1
-- | Metadata wrappers are transparent.
instance ListBaseCases n f => ListBaseCases n (M1 i c f) where
  listBaseCases = fmap (fmap M1) listBaseCases
-- | At depth 'Z' no field may be entered, so a field yields no base case.
instance ListBaseCases Z (K1 i c) where
  listBaseCases = Tagged { unTagged = empty }
-- | At depth @'S' n@ a field's base cases are those of its type at depth @n@.
instance (Generic c, ListBaseCases n (Rep c)) => ListBaseCases (S n) (K1 i c) where
  listBaseCases = retag (fmap (fmap (K1 . to)) listBaseCases)
    where
      -- Move the result from depth n up to depth S n; the payload is unchanged.
      retag :: Tagged n a -> Tagged (S n) a
      retag (Tagged x) = Tagged x
-- | Base cases of a sum: those of the left side, then those of the right.
instance (ListBaseCases n f, ListBaseCases n g) => ListBaseCases n (f :+: g) where
  listBaseCases =
    alt (fmap (fmap L1) listBaseCases) (fmap (fmap R1) listBaseCases)
    where
      -- '<|>' under the depth tag.
      alt :: Alternative u => Tagged n (u a) -> Tagged n (u a) -> Tagged n (u a)
      alt (Tagged x) (Tagged y) = Tagged (x <|> y)
-- | Base cases of a product: every combination of base cases of both halves.
instance (ListBaseCases n f, ListBaseCases n g) => ListBaseCases n (f :*: g) where
  listBaseCases = pairUp listBaseCases listBaseCases
    where
      -- Applicative pairing under the depth tag.
      pairUp
        :: Applicative u
        => Tagged n (u (f p))
        -> Tagged n (u (g p))
        -> Tagged n (u ((f :*: g) p))
      pairUp (Tagged left) (Tagged right) = Tagged ((:*:) <$> left <*> right)