module Wingman.KnownStrategies.QuickCheck where

import ConLike (ConLike(RealDataCon))
import Data.Bool (bool)
import Data.Generics (everything, mkQ)
import Data.List (partition)
import DataCon (DataCon, dataConName)
import Development.IDE.GHC.Compat (GhcPs, HsExpr, noLoc)
import GHC.Exts (IsString (fromString))
import GHC.List (foldl')
import GHC.SourceGen (int)
import GHC.SourceGen.Binds (match, valBind)
import GHC.SourceGen.Expr (case', lambda, let')
import GHC.SourceGen.Overloaded (App ((@@)), HasList (list))
import GHC.SourceGen.Pat (conP)
import OccName (HasOccName (occName), mkVarOcc, occNameString)
import Refinery.Tactic (goal, rule, failure)
import TyCon (TyCon, tyConDataCons, tyConName)
import Type (splitTyConApp_maybe)
import Wingman.CodeGen
import Wingman.Judgements (jGoal)
import Wingman.Machinery (tracePrim)
import Wingman.Types


------------------------------------------------------------------------------
-- | Known tactic for deriving @arbitrary :: Gen a@. This tactic splits the
-- type's data cons into terminal and inductive cases, and generates code that
-- produces a terminal if the QuickCheck size parameter is <=1, or any data con
-- otherwise. It correctly scales recursive parameters, ensuring termination.
--
-- The tactic fails with 'GoalMismatch' unless the goal is @Gen T@ where @T@
-- is a tycon with at least one terminal (non-recursive) data con. Without one,
-- the generated @oneof terminal@ would be @oneof []@ whenever the size is <=1,
-- and QuickCheck's 'oneof' requires a non-empty list.
deriveArbitrary :: TacticsM ()
deriveArbitrary = do
  ty <- jGoal <$> goal
  case splitTyConApp_maybe $ unCType ty of
    Just (gen_tc, [splitTyConApp_maybe -> Just (tc, apps)])
      | occNameString (occName $ tyConName gen_tc) == "Gen"
      , (terminal, big) <-
          partition ((== 0) . genRecursiveCount) $
            fmap (mkGenerator tc apps) $ tyConDataCons tc
        -- Refuse rather than synthesize a generator that is guaranteed to
        -- hit @oneof []@ at small sizes.
      , not $ null terminal ->
          rule $ \_ -> do
            let terminal_expr = mkVal "terminal"
                oneof_expr    = mkVal "oneof"
            pure
              $ Synthesized (tracePrim "deriveArbitrary")
                  -- TODO(sandy): This thing is not actually empty! We produced
                  -- a bespoke binding "terminal", and a not-so-bespoke "n".
                  -- But maybe it's fine for known rules?
                  mempty
                  mempty
                  mempty
              $ noLoc $
                  let' [valBind (fromString "terminal") $ list $ fmap genExpr terminal] $
                    appDollar (mkFunc "sized") $ lambda [bvar' (mkVarOcc "n")] $
                      case' (infixCall "<=" (mkVal "n") (int 1))
                        [ match [conP (fromString "True") []] $
                            oneof_expr @@ terminal_expr
                        , match [conP (fromString "False") []] $
                            appDollar oneof_expr $
                              infixCall "<>"
                                (list $ fmap genExpr big)
                                terminal_expr
                        ]
    _ -> failure $ GoalMismatch "deriveArbitrary" ty


------------------------------------------------------------------------------
-- | Helper data type for the generator of a specific data con: the code that
-- generates it, paired with how many of its fields are recursive.
data Generator = Generator
  { genRecursiveCount :: Integer
    -- ^ How many of the data con's fields have a type mentioning the tycon
    -- being generated. A count of @0@ marks a terminal data con.
  , genExpr           :: HsExpr GhcPs
    -- ^ The generator expression for this data con, e.g. @pure Con@ or
    -- @Con \<$\> arbitrary \<*\> ...@.
  }


------------------------------------------------------------------------------
-- | Make a 'Generator' for a given tycon instantiated with the given @[Type]@.
--
-- A data con without fields becomes @pure Con@. Otherwise the data con is
-- applied applicatively, @Con \<$\> g1 \<*\> g2 \<*\> ...@, using one
-- 'mkArbitraryCall' per field. The recorded recursive count is the number of
-- fields whose type contains @tc@.
mkGenerator :: TyCon -> [Type] -> DataCon -> Generator
mkGenerator tc apps dc =
  Generator recursive_fields $
    case fmap (mkArbitraryCall tc recursive_fields) field_tys of
      [] -> mkFunc "pure" @@ con
      (first_gen : rest_gens) ->
        foldl' (infixCall "<*>") (infixCall "<$>" con first_gen) rest_gens
  where
    con :: HsExpr GhcPs
    con = var' $ occName $ dataConName dc

    field_tys :: [Type]
    field_tys = conLikeInstOrigArgTys' (RealDataCon dc) apps

    recursive_fields :: Integer
    recursive_fields = sum [ bool 0 1 (doesTypeContain tc t) | t <- field_tys ]


------------------------------------------------------------------------------
-- | Check if the given 'TyCon' exists anywhere in the 'Type'.
--
-- Runs a generic @everything@ traversal over the type, OR-ing together a query
-- that is 'True' exactly at 'TyCon' nodes equal to @recursive_tc@.
doesTypeContain :: TyCon -> Type -> Bool
doesTypeContain recursive_tc ty = everything (||) (mkQ False isTarget) ty
  where
    isTarget :: TyCon -> Bool
    isTarget = (recursive_tc ==)


------------------------------------------------------------------------------
-- | Generate the correct sort of call to @arbitrary@. If the field's type
-- contains @recursive_tc@, the call is wrapped in @scale@ so the size
-- parameter shrinks. When @n@ is 1, the size is decremented with
-- @subtract 1@. Otherwise it is divided with @flip div n@. Here @n@ is the
-- count of recursive fields supplied by the caller. For all other types, just
-- call @arbitrary@ directly.
mkArbitraryCall :: TyCon -> Integer -> Type -> HsExpr GhcPs
mkArbitraryCall recursive_tc n ty
  | doesTypeContain recursive_tc ty = mkFunc "scale" @@ shrink_size @@ arbitrary
  | otherwise                       = arbitrary
  where
    arbitrary :: HsExpr GhcPs
    arbitrary = mkFunc "arbitrary"

    -- The size-transforming function handed to @scale@.
    shrink_size :: HsExpr GhcPs
    shrink_size
      | n == 1    = mkFunc "subtract" @@ int 1
      | otherwise = mkFunc "flip" @@ mkFunc "div" @@ int n