{-# LANGUAGE ExplicitNamespaces, MultiWayIf, PatternGuards, TemplateHaskell #-} {-# LANGUAGE TupleSections #-} module Proof.Propositional.TH where import Proof.Propositional.Empty import Proof.Propositional.Inhabited import Control.Arrow (Kleisli (..), second) import Control.Monad (forM, zipWithM) import Data.Foldable (asum) import Data.Map (Map) import qualified Data.Map as M import Data.Maybe (fromJust) import Data.Monoid ((<>)) import Data.Type.Equality ((:~:) (..)) import Language.Haskell.TH (DecsQ, Lit (CharL, IntegerL)) import Language.Haskell.TH (Name, Q, TypeQ, isInstance) import Language.Haskell.TH (newName, ppr) import Language.Haskell.TH.Desugar (DClause (..), DCon (..)) import Language.Haskell.TH.Desugar (DConFields (..), DCxt, DDec (..)) import Language.Haskell.TH.Desugar (DExp (..), DInfo (..)) import Language.Haskell.TH.Desugar (DLetDec (DFunD)) import Language.Haskell.TH.Desugar (DPat (DConPa, DVarPa), DPred (..)) import Language.Haskell.TH.Desugar (DTyVarBndr (..), DType (..)) import Language.Haskell.TH.Desugar (Overlap (Overlapping), desugar) import Language.Haskell.TH.Desugar (dsReify, expandType, substTy) import Language.Haskell.TH.Desugar (sweeten) -- | Macro to automatically derive @'Empty'@ instance for -- concrete (variable-free) types which may contain products. 
--
-- For @l ':~:' r@ an instance is generated only when @l@ and @r@ are
-- provably different ('compareType' returns 'NonEqual'). For any other
-- data type, every constructor must carry at least one field whose type
-- is already an instance of 'Empty'; 'eliminate' then refutes each
-- constructor through the first such field. A data type with no
-- constructors at all is refuted directly by an empty @case@.
--
-- The generated code for the '(:~:)' and no-constructor cases contains an
-- empty @case@, so the splice site needs @EmptyCase@ enabled.
refute :: TypeQ -> DecsQ
refute tps = do
  tp <- expandType =<< desugar =<< tps
  (tyName, args) <- case splitType tp of
    Just (_, n, as) -> return (n, as)
    Nothing -> fail $ "Not a type-constructor application: "
                   ++ show (ppr $ sweeten tp)
  let mkInst dxt cls = return $ sweeten
        [ DInstanceD (Just Overlapping) dxt
            (DAppT (DConT ''Empty) (foldl DAppT (DConT tyName) args))
            [DLetDec $ DFunD 'eliminate cls]
        ]
      -- Instance whose only clause is @eliminate = \_v -> case _v of {}@.
      mkAbsurdInst dxt = do
        v <- newName "_v"
        mkInst dxt [DClause [] $ DLamE [v] (DCaseE (DVarE v) [])]
  if tyName == ''(:~:)
    then case args of
      [l, r] -> do
        dist <- compareType l r
        let sides = show (ppr $ sweeten l) ++ " ~ " ++ show (ppr $ sweeten r)
        case dist of
          NonEqual    -> mkAbsurdInst []
          Equal       -> fail $ "Equal: " ++ sides
          Undecidable -> fail $ "No enough info to check non-equality: " ++ sides
      _ -> fail "(:~:) must be applied to exactly two types"
    else do
      info <- maybe (fail $ "Cannot reify type constructor: " ++ show tyName)
                    return
                =<< dsReify tyName
      (dxt, cons) <- resolveSubsts args info
      if null cons
        then mkAbsurdInst dxt
        else do
          mcls <- sequence <$> mapM buildRefuteClause cons
          case mcls of
            Just cls -> mkInst dxt cls
            Nothing  -> fail $ "Some constructor of " ++ show tyName
                            ++ " has no field with an Empty instance"

-- | Macro to automatically derive @'Inhabited'@ instance for
-- concrete (variable-free) types which may contain sums.
--
-- Types with a 'Num' instance are witnessed by the literal @0@, 'Char' by
-- the NUL character, and @l ':~:' r@ by 'Refl' when 'compareType' decides
-- the two sides are 'Equal'. For any other data type, the first
-- constructor whose fields are all already 'Inhabited' is applied to
-- 'trivial' once per field.
prove :: TypeQ -> DecsQ
prove tps = do
  tp <- expandType =<< desugar =<< tps
  (tyName, args) <- case splitType tp of
    Just (_, n, as) -> return (n, as)
    Nothing -> fail $ "Not a type-constructor application: "
                   ++ show (ppr $ sweeten tp)
  let mkInst dxt cls = return $ sweeten
        [ DInstanceD (Just Overlapping) dxt
            (DAppT (DConT ''Inhabited) (foldl DAppT (DConT tyName) args))
            [DLetDec $ DFunD 'trivial cls]
        ]
  isNum <- isInstance ''Num [sweeten tp]
  if | isNum -> mkInst [] [DClause [] $ DLitE $ IntegerL 0]
     | tyName == ''Char -> mkInst [] [DClause [] $ DLitE $ CharL '\NUL']
     | tyName == ''(:~:) -> case args of
         [l, r] -> do
           dist <- compareType l r
           let sides = show (ppr $ sweeten l) ++ " ~ " ++ show (ppr $ sweeten r)
           case dist of
             -- The two sides are provably different, so 'Refl' cannot exist.
             NonEqual    -> fail $ "Non-equal: " ++ sides
             Equal       -> mkInst [] [DClause [] $ DConE 'Refl]
             Undecidable -> fail $ "No enough info to check non-equality: " ++ sides
         _ -> fail "(:~:) must be applied to exactly two types"
     | otherwise -> do
         info <- maybe (fail $ "Cannot reify type constructor: " ++ show tyName)
                       return
                   =<< dsReify tyName
         (dxt, cons) <- resolveSubsts args info
         mcl <- asum <$> mapM buildProveClause cons
         case mcl of
           Just cl -> mkInst dxt [cl]
           Nothing -> fail $ "No constructor of " ++ show tyName
                          ++ " has only Inhabited fields"

-- | Shared worker behind 'buildRefuteClause' and 'buildProveClause'.
--
-- For the given constructor it:
--
-- 1. creates one placeholder per field with @genPlaceHolder@;
-- 2. checks, for each field type, whether it is an instance of @clsName@;
-- 3. turns every such field into an expression with @buildFactor@
--    (fields without an instance become 'Nothing').
--
-- @flattenExps@ combines the per-field results into the clause body, and
-- @toPats@ builds the clause's argument patterns from the constructor name
-- and the placeholders. The result is 'Nothing' exactly when
-- @flattenExps@ returns 'Nothing'.
buildClause :: Name -> (DType -> Q b) -> (DType -> b -> DExp)
            -> (Name -> [Maybe DExp] -> Maybe DExp)
            -> (Name -> [b] -> [DPat])
            -> DCon -> Q (Maybe DClause)
buildClause clsName genPlaceHolder buildFactor flattenExps toPats (DCon _ _ cName flds _) = do
  let tys = fieldsVars flds
  varDic <- mapM genPlaceHolder tys
  fmap (DClause $ toPats cName varDic) . flattenExps cName
    <$> zipWithM tryProc tys varDic
  where
    tryProc ty name = do
      isInst <- isInstance clsName [sweeten ty]
      return $ if isInst then Just (buildFactor ty name) else Nothing

-- | Clause @eliminate (C _x1 .. _xn) = eliminate _xi@, where @_xi@ is the
-- first field whose type has an 'Empty' instance. 'Nothing' when no field
-- qualifies.
buildRefuteClause :: DCon -> Q (Maybe DClause)
buildRefuteClause =
  buildClause ''Empty (const $ newName "_x")
    (const $ (DVarE 'eliminate `DAppE`) . DVarE)
    (const asum)
    (\cName ps -> [DConPa cName $ map DVarPa ps])

-- | Clause @trivial = C trivial .. trivial@. 'Nothing' unless every field
-- type has an 'Inhabited' instance.
buildProveClause :: DCon -> Q (Maybe DClause)
buildProveClause =
  buildClause ''Inhabited (const $ return ())
    (const $ const $ DVarE 'trivial)
    (\con args -> foldl DAppE (DConE con) <$> sequence args)
    (const $ const [])

-- | Field types of a constructor, in declaration order.
fieldsVars :: DConFields -> [DType]
fieldsVars (DNormalC fs) = map snd fs
fieldsVars (DRecC fs)    = map (\(_, _, c) -> c) fs

-- | For a plain @data@/@newtype@ declaration, return its datatype context
-- and its constructors. In the constructors, the declared type variables
-- are substituted positionally with @args@.
--
-- Type families, data families and data instances are not supported.
resolveSubsts :: [DType] -> DInfo -> Q (DCxt, [DCon])
resolveSubsts args info =
  case info of
    DTyConI (DDataD _ cxt _ tvbs dcons _) _ -> do
      let dic = M.fromList $ zip (map dtvbToName tvbs) args
      (cxt, ) <$> mapM (substDCon dic) dcons
    DTyConI _ _ -> fail "Not supported data ty"
    _ -> fail "Please pass data-type"

-- | Substitution from type-variable names to types.
type SubstDic = Map Name DType

-- | Apply a substitution to a constructor's field types and its optional
-- result-type annotation. The constructor's own binders and context are
-- left untouched.
substDCon :: SubstDic -> DCon -> Q DCon
substDCon dic (DCon forall'd cxt conName fields mPhantom) =
  DCon forall'd cxt conName
    <$> substFields dic fields
    <*> mapM (substTy dic) mPhantom

-- | Apply a substitution to every field type, keeping strictness
-- annotations and record field names.
substFields :: SubstDic -> DConFields -> Q DConFields
substFields subst (DNormalC fs) =
  DNormalC <$> mapM (runKleisli $ second $ Kleisli $ substTy subst) fs
substFields subst (DRecC fs) =
  DRecC <$> forM fs (\(a, b, c) -> (a, b, ) <$> substTy subst c)

-- | The variable name bound by a type-variable binder.
dtvbToName :: DTyVarBndr -> Name
dtvbToName (DPlainTV n)    = n
dtvbToName (DKindedTV n _) = n

-- | Split a type into its @forall@-bound variable names, its head type
-- constructor and its arguments (in application order). Kind signatures
-- are looked through; @->@ is reported as @''(->)@. Returns 'Nothing'
-- when the head is a type variable, a type literal, a wildcard or @*@.
splitType :: DType -> Maybe ([Name], Name, [DType])
splitType (DForallT vs _ t) =
  (\(a, b, c) -> (map dtvbToName vs ++ a, b, c)) <$> splitType t
splitType (DAppT t1 t2) = (\(a, b, c) -> (a, b, c ++ [t2])) <$> splitType t1
splitType (DSigT t _) = splitType t
splitType (DVarT _) = Nothing
splitType (DConT n) = Just ([], n, [])
splitType DArrowT = Just ([], ''(->), [])
splitType (DLitT _) = Nothing
splitType DWildCardT = Nothing
splitType DStarT = Nothing

-- | Outcome of comparing two types:
--
-- * 'NonEqual': they are provably different.
-- * 'Undecidable': there is not enough information, e.g. type variables
--   or wildcards are involved.
-- * 'Equal': they are provably the same.
data EqlJudge = NonEqual | Undecidable | Equal
              deriving (Read, Show, Eq, Ord)

-- | Combine the judgements for the components of a type. A single
-- 'NonEqual' component makes the whole 'NonEqual'; otherwise any
-- 'Undecidable' component makes it 'Undecidable'. 'Equal' is the identity.
--
-- NOTE(review): GHC >= 8.4 (base >= 4.11) additionally requires a
-- 'Semigroup' instance here, and @Data.Monoid.(<>)@ then refers to it.
instance Monoid EqlJudge where
  NonEqual    `mappend` _        = NonEqual
  Undecidable `mappend` NonEqual = NonEqual
  Undecidable `mappend` _        = Undecidable
  Equal       `mappend` m        = m

  mempty = Equal

-- | Expand type synonyms and families in both types, then compare them
-- structurally with 'compareType''.
compareType :: DType -> DType -> Q EqlJudge
compareType t0 s0 = do
  t <- expandType t0
  s <- expandType s0
  compareType' t s

-- | Structural comparison of two already-expanded types.
--
-- Type variables and wildcards yield 'Undecidable' unless both sides are
-- the same variable. @forall@s with the same number of binders are
-- compared after renaming the right-hand binders to the left-hand ones.
--
-- NOTE(review): an application of a type family that did not reduce is
-- compared like any other application, so it can be reported 'NonEqual'
-- against a type it might reduce to.
compareType' :: DType -> DType -> Q EqlJudge
compareType' (DSigT t1 t2) (DSigT s1 s2) =
  (<>) <$> compareType' t1 s1 <*> compareType' t2 s2
compareType' (DSigT t _) s = compareType' t s
compareType' t (DSigT s _) = compareType' t s
compareType' (DVarT t) (DVarT s)
  | t == s    = return Equal
  | otherwise = return Undecidable
compareType' (DVarT _) _ = return Undecidable
compareType' _ (DVarT _) = return Undecidable
compareType' DWildCardT _ = return Undecidable
compareType' _ DWildCardT = return Undecidable
compareType' (DForallT tTvBs tCxt t) (DForallT sTvBs sCxt s)
  | length tTvBs == length sTvBs = do
      -- Alpha-rename the right-hand binders to the left-hand ones.
      let dic = M.fromList $
            zip (map dtvbToName sTvBs) (map (DVarT . dtvbToName) tTvBs)
      s' <- substTy dic s
      pd <- compareCxt tCxt =<< mapM (substPred dic) sCxt
      bd <- compareType' t s'
      return (pd <> bd)
  | otherwise = return NonEqual
compareType' (DForallT _ _ _) _ = return NonEqual
compareType' (DAppT t1 t2) (DAppT s1 s2) =
  (<>) <$> compareType' t1 s1 <*> compareType' t2 s2
compareType' (DConT t) (DConT s)
  | t == s    = return Equal
  | otherwise = return NonEqual
compareType' (DConT _) _ = return NonEqual
compareType' DArrowT DArrowT = return Equal
compareType' DArrowT _ = return NonEqual
compareType' (DLitT t) (DLitT s)
  | t == s    = return Equal
  | otherwise = return NonEqual
compareType' (DLitT _) _ = return NonEqual
compareType' DStarT DStarT = return Equal
compareType' _ _ = return NonEqual

-- | Compare two contexts predicate-by-predicate, by position.
--
-- Contexts of different lengths are reported as 'Undecidable' rather than
-- compared on their common prefix. The comparison is positional, so it
-- cannot safely conclude either equality or inequality in that case.
compareCxt :: DCxt -> DCxt -> Q EqlJudge
compareCxt l r
  | length l /= length r = return Undecidable
  | otherwise            = mconcat <$> zipWithM comparePred l r

-- | Structural comparison of two predicates, in the same spirit as
-- 'compareType''. Predicate variables and wildcards yield 'Undecidable'
-- unless both sides are the same variable.
comparePred :: DPred -> DPred -> Q EqlJudge
comparePred DWildCardPr _ = return Undecidable
comparePred _ DWildCardPr = return Undecidable
comparePred (DVarPr l) (DVarPr r)
  | l == r = return Equal
comparePred (DVarPr _) _ = return Undecidable
comparePred _ (DVarPr _) = return Undecidable
comparePred (DSigPr l t) (DSigPr r s) =
  (<>) <$> compareType' t s <*> comparePred l r
comparePred (DSigPr l _) r = comparePred l r
comparePred l (DSigPr r _) = comparePred l r
comparePred (DAppPr l1 l2) (DAppPr r1 r2) = do
  l2' <- expandType l2
  r2' <- expandType r2
  (<>) <$> comparePred l1 r1 <*> compareType' l2' r2'
comparePred (DAppPr _ _) _ = return NonEqual
comparePred (DConPr l) (DConPr r)
  | l == r    = return Equal
  | otherwise = return NonEqual
comparePred (DConPr _) _ = return NonEqual

-- | Apply a substitution inside a predicate, expanding the substituted
-- argument and kind types. A predicate variable that the substitution maps
-- to a type variable or a type constructor becomes the corresponding
-- predicate variable/constructor. Anything else is returned unchanged.
substPred :: SubstDic -> DPred -> Q DPred
substPred dic (DAppPr p1 p2) =
  DAppPr <$> substPred dic p1 <*> (expandType =<< substTy dic p2)
substPred dic (DSigPr p knd) =
  DSigPr <$> substPred dic p <*> (expandType =<< substTy dic knd)
substPred dic prd@(DVarPr p)
  | Just (DVarT t) <- M.lookup p dic = return $ DVarPr t
  | Just (DConT t) <- M.lookup p dic = return $ DConPr t
  | otherwise                        = return prd
substPred _ t = return t