{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE ViewPatterns #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeInType #-}
{-# LANGUAGE TupleSections #-}
{-# LANGUAGE Trustworthy #-}

-- * Type family to proofs

-- * TODO
-- - dissolve BIG DOUBT
-- - allow for unlifted arrow arguments?
-- - remove the `Lift Exp` hack (or revert to what I had before),
--   see https://ghc.haskell.org/t/14296
-- - clients need to activate ScopedTypeVariables
-- - type family a /=/ b where, how to name the artifacts?
-- - take advantage of 'InjectivityAnn'?

module TyFamWitnesses (witnesses, pattern Lifted, pattern Fun') where

import GHC.Exts
import Data.Kind
import Language.Haskell.TH hiding (Type)
import Language.Haskell.TH.Syntax hiding (Type)
import qualified Language.Haskell.TH.Syntax as HS(Type)
import Data.Maybe
import Type.Reflection
import Data.Type.Equality hiding (apply)
import Data.Char (toLower)
import Unsafe.Coerce (unsafeCoerce)
import Debug.Trace
import Data.Map.Strict hiding (mapMaybe, take, foldr, foldl', filter)
import Data.List (foldl', nub)
import Data.Function

-- | Splice entry point.
--
-- Returns the quoted declarations unchanged. For every closed type family
-- among them, it appends:
--
-- * a witness GADT, built by 'buildGADT';
-- * a run-time reifier function, built by 'buildReifier'.
--
-- All GADTs come before all reifiers.
witnesses :: DecsQ -> DecsQ
witnesses decsQ = do
  decs <- decsQ
  generated <- catMaybes <$> traverse buildGADT decs
  let gadts         = fst <$> generated
      reifierInputs = snd <$> generated
  reifiers <- traverse buildReifier reifierInputs
  pure (decs ++ gadts ++ concat reifiers)

-- | Build the witness GADT for a closed type family @F@.
--
-- The GADT is a fresh @FRefl@ with the family's parameters and one
-- constructor per equation, in equation order (see 'clauseToCtor').
--
-- It returns the GADT declaration together with the inputs 'buildReifier'
-- needs:
--
-- * the GADT's name;
-- * its constructors as expressions, in equation order;
-- * the family declaration itself.
--
-- Any other declaration yields 'Nothing'.
buildGADT :: Dec -> Q (Maybe (Dec, (Name, [Exp], Dec)))
buildGADT fam@(ClosedTypeFamilyD (TypeFamilyHead famName tyvars _ _) clauses) = do
  ctorNames <- sequence (take (length clauses) (generateConstrNames famName))
  gadtName <- newName (famOcc ++ "Refl")
  let ctors = zipWith (clauseToCtor famName gadtName) ctorNames clauses
      gadt  = DataD [] gadtName tyvars Nothing ctors []
  pure (Just (gadt, (gadtName, ConE <$> ctorNames, fam)))
  where
    Name (OccName famOcc) _ = famName
buildGADT _ = pure Nothing

-- | An infinite, lazily generated supply of fresh names @F0@, @F1@, ...
-- built from the occurrence name of the family @F@.
generateConstrNames :: Name -> [Q Name]
generateConstrNames (Name (OccName base) _) =
  [newName (base ++ show i) | i <- [0 :: Integer ..]]
-- | Turn one equation of a closed type family into a constructor of the
-- witness GADT.
--
-- For family @fam@, GADT @n@ and an equation @fam t1 .. tk = r@, the result
-- is the GADT constructor
--
-- > constrName :: (fam t1 .. tk ~ r) => n t1 .. tk
--
-- Matching on the constructor therefore brings the equation into scope as a
-- local equality. The 'ForallC' binder list is left empty. The equation's
-- type variables are presumably quantified implicitly by GHC when the
-- declaration is spliced.
clauseToCtor :: Name -> Name -> Name -> TySynEqn -> Con
clauseToCtor fam n constrName (TySynEqn formals returnt) =
  ForallC [] [foldl' AppT (ConT fam) formals `AppEqT` returnt] $
    GadtC [constrName] [] (foldl' AppT (ConT n) formals)

-- | Orphan hack (see https://ghc.haskell.org/t/14296).
--
-- "Lifting" an 'Exp' hands it back unchanged. The pattern quote in
-- 'buildReifier' mentions the locally bound @eqTypeRep :: Exp@. Referring to
-- a local variable inside a quote requires a 'Lift' instance for its type.
instance Lift Exp where
  lift = pure

-- | Debugging aid (not exported).
--
-- Always matches, printing the scrutinee with 'traceShow' while doing so.
pattern TRACE <- ((`traceShow` ()) -> ())

-- | Build the run-time reifier for one closed type family.
--
-- For a family @F a1 .. ak@ with witness GADT @dname@, this generates
--
-- > reify_F :: forall a1 .. ak. TypeRep a1 -> .. -> TypeRep ak -> Maybe (dname a1 .. ak)
--
-- The function has one clause per equation, tried in equation order. A
-- clause matches the 'TypeRep' arguments against the equation's left-hand
-- side and returns the corresponding entry of @constrs@, which is expected
-- to be in equation order.
--
-- When no equation is a catch-all, a final clause returns 'Nothing'. An
-- equation is a catch-all when all its arguments are pairwise distinct type
-- variables.
--
-- Declarations other than closed type families produce no code.
buildReifier :: (Name, [Exp], Dec) -> DecsQ
buildReifier (dname, constrs, ClosedTypeFamilyD (TypeFamilyHead name formals _ _) clauses) = do
  n <- newName ("reify_" ++ occ)
  -- Reusable TH fragments.
  tyr <- [t|TypeRep|]
  mbe <- [t|Maybe|]
  hrefl <- [p|HRefl|]
  prefl <- [p|Refl|]
  refl <- [e|Refl|]
  noth <- [e|Nothing|]
  unsCoe <- [e|unsafeCoerce|]
  ConP just _ <- [p|Just|]
  propEq <- [t|(:~:)|]
  pre <- [e|pure|]
  eqTypeRep <- [e|eqTypeRep|]
  ConP fun _ <- [p|Fun'|]
  ConP app _ <- [p|App|]
  let el `appEqTypeRep` er = eqTypeRep `AppE` el `AppE` er  -- eqTypeRep el er
      result = (pre `AppE`)                                 -- pure e
      coercedRefl = unsCoe `AppE` refl                      -- unsafeCoerce Refl
      justHRefl = ConP just [hrefl]                         -- Just HRefl
      joinKeys = unionWith (++)

      -- View pattern that accepts only the TypeRep of exactly the closed
      -- type @ty@. It binds no variables.
      exactly ty = (, empty) <$> [p|(eqTypeRep (typeRep @ $(pure ty)) -> Just HRefl)|]

      -- Add one argument to a pattern. A 'Fun'' constructor pattern (from
      -- an arrow) collects its arguments. Anything else is combined with
      -- the 'App' pattern synonym.
      ConP c args `appP` argP = ConP c (args ++ [argP])
      fnP `appP` argP = InfixP fnP app argP

      -- Translate one left-hand-side argument into a pattern over its
      -- TypeRep. Also return the fresh term variables bound for each type
      -- variable it mentions; a variable occurring twice gets two names.
      t2p :: HS.Type -> Q (Pat, Name `Map` [Name])
      t2p (VarT tyVar@(Name (OccName base) _)) = do
        fresh <- newName base
        pure (VarP fresh, tyVar `singleton` [fresh])
      t2p ArrowT = pure (ConP fun [], empty)
      t2p (fn `AppT` arg) = do
        (fnP, fnVars) <- t2p fn
        (argP, argVars) <- t2p arg
        pure (fnP `appP` argP, joinKeys fnVars argVars)
      t2p ty@ConT{} = exactly ty
      t2p ty@LitT{} = exactly ty
      t2p ty@PromotedT{} = exactly ty
      t2p ty@ListT = exactly ty
      t2p ty@TupleT{} = exactly ty
      t2p wtf = error $ "what is this? " ++ show wtf

      -- An equation is a catch-all when its arguments are pairwise
      -- distinct type variables.
      total :: TySynEqn -> Bool
      total (TySynEqn actuals _) = all isVar actuals && length actuals == length (nub actuals)
        where
          isVar VarT{} = True
          isVar _ = False

      -- One function clause per equation. @isFirst@ marks the family's
      -- first equation.
      clause :: Bool -> TySynEqn -> Exp -> Q Clause
      -- Coercion-free special case: the FIRST equation has the shape
      -- @F x x = r@ with @x@ a type variable. Once the two TypeReps are
      -- known to be equal, GHC reduces @F a a@ to @r@ by itself, because no
      -- earlier equation needs to be ruled out.
      --
      -- Two other cases must not take this path:
      --
      -- * A later equation: GHC may refuse to reduce it when an earlier
      --   equation overlaps.
      -- * A repeated non-variable argument such as @F Int Int@: it must also
      --   be checked against the concrete type.
      --
      -- Both are handled by the general case below.
      clause isFirst (TySynEqn [p0, p1] _) constr
        | isFirst, VarT _ <- p0, p0 == p1 =
            pure $ Clause formalsPatterns
                          (GuardedB [(PatG [BindS justHRefl formalsEquated], result constr)])
                          []
        where
          formalsPatterns = tv VarP <$> formals
          formalsEquated = foldl1 appEqTypeRep (tv VarE <$> formals)
      -- General case, in three steps:
      --
      -- 1. Match every argument structurally.
      -- 2. Require all occurrences of a repeated type variable to be equal.
      -- 3. Postulate @r :~: F actuals@ with 'unsafeCoerce'.
      --
      -- The postulate is only evaluated after every argument pattern and
      -- equality guard has succeeded, and after all earlier clauses failed
      -- to match. It mirrors the family's own reduction of these arguments.
      clause _ (TySynEqn actuals returnt) constr = do
        (pats, varMaps) <- unzip <$> traverse t2p actuals
        let occurrences = nub . elems $ foldl' joinKeys empty varMaps
            -- k occurrences of one variable yield k-1 chained equality checks.
            chainEqs names@(_ : rest) = BindS justHRefl <$> zipWith (appEqTypeRep `on` VarE) names rest
            chainEqs [] = []
            binders = concatMap chainEqs occurrences
            constraints = zipWith (\v actual -> tv VarT v `AppEqT` actual) formals actuals
            appliedActuals = foldl' AppT (ConT name) actuals
            postulate =
              BindS prefl
                    (SigE coercedRefl
                          (ForallT [] constraints (propEq `AppT` returnt `AppT` appliedActuals)))
        pure $ Clause pats (GuardedB [(PatG (binders ++ [postulate]), result constr)]) []

  cs <- sequence (zipWith3 clause (True : repeat False) clauses constrs)
  -- Without a catch-all equation, some inputs match no clause.
  let catchAll = Clause (WildP <$ formals) (NormalB noth) []
      cs' | any total clauses = cs
          | otherwise = cs ++ [catchAll]
      sigType = ForallT formals [] (foldr AppArrowT (mbe `AppT` appliedFormals) (AppT tyr <$> formalsTypes))
  pure [SigD n sigType, FunD n cs']
  where
    Name (OccName occ) _ = name
    formalsTypes = tv VarT <$> formals
    tv how (PlainTV formal) = how formal
    tv how (KindedTV formal _) = how formal
    appliedFormals = foldl' AppT (ConT dname) formalsTypes
buildReifier _ = pure []

-- | @l `AppArrowT` r@ is the TH representation of the function type @l -> r@.
pattern AppArrowT l r = AppT (AppT ArrowT l) r
infixr 1 `AppArrowT`

-- | @l `AppEqT` r@ is the TH representation of the equality constraint @l ~ r@.
pattern AppEqT l r = AppT (AppT EqualityT l) r
infixr 2 `AppEqT`

-- | Matches a 'TypeRep' whose kind is 'Type'. It exposes the same
-- representation at a type @t@ known to be (heterogeneously) equal to the
-- original index.
pattern Lifted :: forall k (a :: k). () => forall (t :: Type). t ~~ a => TypeRep t -> TypeRep a
pattern Lifted c <- ((\r -> (typeRepKind r `testEquality` typeRep @Type, r)) -> (Just Refl, c))

-- | Like 'Fun', but both the argument and the result must have kind 'Type'.
-- It therefore only matches ordinary function types @d -> c@ between lifted
-- types.
pattern Fun' :: forall k (f :: k). () => forall c d. (d -> c) ~~ f => TypeRep d -> TypeRep c -> TypeRep f
pattern d `Fun'` c <- Lifted d `Fun` Lifted c