{-# LANGUAGE ViewPatterns #-}
{-# LANGUAGE LambdaCase #-}

module Ide.Plugin.Tactic.KnownStrategies.QuickCheck where

import Control.Monad.Except (MonadError(throwError))
import Data.Bool (bool)
import Data.List (partition)
import DataCon ( DataCon, dataConName )
import Development.IDE.GHC.Compat (HsExpr, GhcPs, noLoc)
import GHC.Exts ( IsString(fromString) )
import GHC.List ( foldl' )
import GHC.SourceGen (int)
import GHC.SourceGen.Binds ( match, valBind )
import GHC.SourceGen.Expr ( case', lambda, let' )
import GHC.SourceGen.Overloaded ( App((@@)), HasList(list) )
import GHC.SourceGen.Pat ( conP )
import Ide.Plugin.Tactic.CodeGen
import Ide.Plugin.Tactic.Judgements (jGoal)
import Ide.Plugin.Tactic.Machinery (tracePrim)
import Ide.Plugin.Tactic.Types
import OccName (occNameString,  mkVarOcc, HasOccName(occName) )
import Refinery.Tactic (goal,  rule )
import TyCon (tyConName,  TyCon, tyConDataCons )
import Type ( splitTyConApp_maybe )
import Data.Generics (mkQ, everything)


------------------------------------------------------------------------------
-- | Known tactic for deriving @arbitrary :: Gen a@. This tactic splits the
-- type's data cons into terminal and inductive cases, and generates code that
-- produces a terminal if the QuickCheck size parameter is <=1, or any data con
-- otherwise. It correctly scales recursive parameters, ensuring termination.
deriveArbitrary :: TacticsM ()
deriveArbitrary = do
  ty <- jGoal <$> goal
  case splitTyConApp_maybe $ unCType ty of
    -- Only goals of the shape @Gen (T a1 .. an)@ for some tycon @T@ match.
    Just (gen_tc, [splitTyConApp_maybe -> Just (tc, apps)])
        | occNameString (occName $ tyConName gen_tc) == "Gen"
          -- With no data cons, both generated branches would be @oneof []@,
          -- which always fails at runtime; reject the goal instead.
        , dcs@(_ : _) <- tyConDataCons tc ->
      rule $ \_ -> do
        -- Data cons with no recursive fields are "terminal" and are the only
        -- ones offered at small sizes; the rest ("big") are added otherwise.
        -- NOTE(review): if every data con is recursive, 'terminal' is empty
        -- and the generated @oneof terminal@ still fails at sizes <= 1.
        let (terminal, big) =
              partition ((== 0) . genRecursiveCount) $
                fmap (mkGenerator tc apps) dcs
            terminal_expr = mkVal "terminal"
            oneof_expr    = mkVal "oneof"
        pure
          ( tracePrim "deriveArbitrary"
          , noLoc $
              -- Generated shape:
              --   let terminal = [<terminal gens>]
              --    in sized $ \n -> case n <= 1 of
              --         True  -> oneof terminal
              --         False -> oneof $ [<big gens>] <> terminal
              let' [valBind (fromString "terminal") $ list $ fmap genExpr terminal] $
                appDollar (mkFunc "sized") $ lambda [bvar' (mkVarOcc "n")] $
                  case' (infixCall "<=" (mkVal "n") (int 1))
                    [ match [conP (fromString "True") []] $
                        oneof_expr @@ terminal_expr
                    , match [conP (fromString "False") []] $
                        appDollar oneof_expr $
                          infixCall "<>"
                            (list $ fmap genExpr big)
                            terminal_expr
                    ]
          )
    _ -> throwError $ GoalMismatch "deriveArbitrary" ty


------------------------------------------------------------------------------
-- | Helper data type for the generator of a specific data con.
data Generator = Generator
  { genRecursiveCount :: Integer
    -- ^ How many of the data con's fields mention the tycon being generated.
  , genExpr           :: HsExpr GhcPs
    -- ^ Expression producing a value built with this data con.
  }


------------------------------------------------------------------------------
-- | Make a 'Generator' for a given tycon instantiated with the given @[Type]@.
mkGenerator :: TyCon -> [Type] -> DataCon -> Generator
mkGenerator tc apps dc =
  Generator recursiveArgs $ case fieldTys of
    -- Nullary data con: just lift it into the generator.
    [] -> mkFunc "pure" @@ conExpr
    -- Otherwise build @Con <$> gen1 <*> gen2 <*> ... <*> genN@.
    (firstTy : restTys) ->
      foldl'
        (infixCall "<*>")
        (infixCall "<$>" conExpr $ arbitraryFor firstTy)
        (map arbitraryFor restTys)
  where
    -- The data con itself, referenced by its occurrence name.
    conExpr :: HsExpr GhcPs
    conExpr = var' . occName $ dataConName dc

    -- Field types of the data con, instantiated at the tycon's arguments.
    fieldTys :: [Type]
    fieldTys = dataConInstOrigArgTys' dc apps

    -- Number of fields whose type mentions the tycon being generated.
    recursiveArgs :: Integer
    recursiveArgs = fromIntegral . length $ filter (doesTypeContain tc) fieldTys

    -- Per-field @arbitrary@ call, scaled when the field is recursive.
    arbitraryFor :: Type -> HsExpr GhcPs
    arbitraryFor = mkArbitraryCall tc recursiveArgs


------------------------------------------------------------------------------
-- | Check if the given 'TyCon' exists anywhere in the 'Type'.
doesTypeContain :: TyCon -> Type -> Bool
doesTypeContain recursive_tc ty = everything (||) (mkQ False matchesTarget) ty
  where
    -- Only 'TyCon' nodes are tested; every other node contributes 'False'.
    matchesTarget :: TyCon -> Bool
    matchesTarget candidate = candidate == recursive_tc


------------------------------------------------------------------------------
-- | Generate the correct sort of call to @arbitrary@. For recursive calls, we
-- need to scale down the size parameter: either by subtracting 1 if it's the
-- only recursive parameter, or by dividing it by @n@ (@`div` n@), where @n@ is
-- the number of recursive parameters. For all other types, just call
-- @arbitrary@ directly.
mkArbitraryCall :: TyCon -> Integer -> Type -> HsExpr GhcPs
mkArbitraryCall recursive_tc n ty
  -- Recursive field: @scale <sizeFn> arbitrary@.
  | doesTypeContain recursive_tc ty = mkFunc "scale" @@ sizeFn @@ callArbitrary
  -- Non-recursive field: plain @arbitrary@.
  | otherwise                       = callArbitrary
  where
    callArbitrary :: HsExpr GhcPs
    callArbitrary = mkFunc "arbitrary"

    -- A single recursive field shrinks the size by one; several recursive
    -- fields share it by dividing it by their count.
    sizeFn :: HsExpr GhcPs
    sizeFn
      | n == 1    = mkFunc "subtract" @@ int 1
      | otherwise = mkFunc "flip" @@ mkFunc "div" @@ int n