{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE PatternGuards #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}

-- |
-- This module implements the generic deriving elaboration that takes place during desugaring.
--
module Language.PureScript.Sugar.TypeClasses.Deriving (deriveInstances) where

import Prelude ()
import Prelude.Compat

import Data.List (foldl', find, sortBy)
import Data.Maybe (fromMaybe)
import Data.Ord (comparing)

import Control.Arrow (second)
import Control.Monad (replicateM)
import Control.Monad.Supply.Class (MonadSupply)
import Control.Monad.Error.Class (MonadError(..))

import Language.PureScript.Crash
import Language.PureScript.AST
import Language.PureScript.Environment
import Language.PureScript.Errors
import Language.PureScript.Names
import Language.PureScript.Types
import qualified Language.PureScript.Constants as C

-- | Elaborates deriving instance declarations by code generation.
--
-- Every declaration of the module is passed through 'deriveInstance'; the
-- source span, comments, module name and export list are kept as they are.
deriveInstances :: (MonadError MultipleErrors m, MonadSupply m) => Module -> m Module
deriveInstances (Module ss coms mn ds exts) = Module ss coms mn <$> mapM (deriveInstance mn ds) ds <*> pure exts

-- | Takes a declaration, and if the declaration is a deriving TypeInstanceDeclaration,
-- elaborates that into an instance declaration via code generation.
--
-- Only single-argument instances of @Data.Generic.Generic@, @Prelude.Eq@ and
-- @Prelude.Ord@ are handled, and only when the instance head is a type
-- constructor (possibly applied to arguments) that is either unqualified or
-- qualified with the current module name. Any other derived instance is
-- rejected with 'CannotDerive'. Positioned declarations are unwrapped,
-- processed and re-wrapped; all other declarations are returned unchanged.
deriveInstance :: (MonadError MultipleErrors m, MonadSupply m) => ModuleName -> [Declaration] -> Declaration -> m Declaration
deriveInstance mn ds (TypeInstanceDeclaration nm deps className tys@[ty] DerivedInstance)
  | className == Qualified (Just dataGeneric) (ProperName C.generic)
  , Just (Qualified mn' tyCon, args) <- unwrapTypeConstructor ty
  , mn == fromMaybe mn mn'
  = TypeInstanceDeclaration nm deps className tys . ExplicitInstance <$> deriveGeneric mn ds tyCon args
  | className == Qualified (Just (ModuleName [ ProperName "Prelude" ])) (ProperName "Eq")
  , Just (Qualified mn' tyCon, _) <- unwrapTypeConstructor ty
  , mn == fromMaybe mn mn'
  = TypeInstanceDeclaration nm deps className tys . ExplicitInstance <$> deriveEq mn ds tyCon
  | className == Qualified (Just (ModuleName [ ProperName "Prelude" ])) (ProperName "Ord")
  , Just (Qualified mn' tyCon, _) <- unwrapTypeConstructor ty
  , mn == fromMaybe mn mn'
  = TypeInstanceDeclaration nm deps className tys . ExplicitInstance <$> deriveOrd mn ds tyCon
deriveInstance _ _ (TypeInstanceDeclaration _ _ className tys DerivedInstance)
  = throwError . errorMessage $ CannotDerive className tys
deriveInstance mn ds (PositionedDeclaration pos com d) = PositionedDeclaration pos com <$> deriveInstance mn ds d
deriveInstance _ _ e = return e

-- | Splits a (possibly applied) type constructor into the constructor name and
-- its arguments in left-to-right order. Returns 'Nothing' when the head of the
-- application spine is not a 'TypeConstructor'.
unwrapTypeConstructor :: Type -> Maybe (Qualified (ProperName 'TypeName), [Type])
unwrapTypeConstructor = fmap (second reverse) . go
  where
  -- Walking down the left-nested 'TypeApp' spine collects the arguments
  -- outermost-first, hence the 'reverse' above.
  go (TypeConstructor tyCon) = Just (tyCon, [])
  go (TypeApp ty arg) = do
    (tyCon, args) <- go ty
    return (tyCon, arg : args)
  go _ = Nothing

-- | The module name @Data.Generic@.
dataGeneric :: ModuleName
dataGeneric = ModuleName [ ProperName "Data", ProperName "Generic" ]

-- | The module name @Data.Maybe@.
dataMaybe :: ModuleName
dataMaybe = ModuleName [ ProperName "Data", ProperName "Maybe" ]

-- | The module name @Type.Proxy@.
typesProxy :: ModuleName
typesProxy = ModuleName [ ProperName "Type", ProperName "Proxy" ]

-- | Generates the member declarations of a derived @Generic@ instance
-- (named by 'C.toSpine', 'C.fromSpine' and 'C.toSignature') for the data type
-- called @tyConNm@, which is looked up among @ds@ with 'findTypeDecl'.
--
-- @dargs@ are the type arguments given in the instance head; they are
-- substituted for the data type's own type variables when building the
-- signature.
deriveGeneric
  :: forall m. (MonadError MultipleErrors m, MonadSupply m)
  => ModuleName
  -> [Declaration]
  -> ProperName 'TypeName
  -> [Type]
  -> m [Declaration]
deriveGeneric mn ds tyConNm dargs = do
  tyCon <- findTypeDecl tyConNm ds
  toSpine <- mkSpineFunction tyCon
  fromSpine <- mkFromSpineFunction tyCon
  let toSignature = mkSignatureFunction tyCon dargs
  return [ ValueDeclaration (Ident C.toSpine) Public [] (Right toSpine)
         , ValueDeclaration (Ident C.fromSpine) Public [] (Right fromSpine)
         , ValueDeclaration (Ident C.toSignature) Public [] (Right toSignature)
         ]
  where

  -- Builds a lambda that cases on its argument, with one alternative per data
  -- constructor. Each alternative yields an @SProd@ applied to the qualified
  -- constructor name and an array holding one lambda per constructor field.
  mkSpineFunction :: Declaration -> m Expr
  mkSpineFunction (DataDeclaration _ _ _ args) = do
    x <- freshIdent'
    lamCase x <$> mapM mkCtorClause args
    where
    prodConstructor :: Expr -> Expr
    prodConstructor = App (Constructor (Qualified (Just dataGeneric) (ProperName "SProd")))

    recordConstructor :: Expr -> Expr
    recordConstructor = App (Constructor (Qualified (Just dataGeneric) (ProperName "SRecord")))

    mkCtorClause :: (ProperName 'ConstructorName, [Type]) -> m CaseAlternative
    mkCtorClause (ctorName, tys) = do
      idents <- replicateM (length tys) freshIdent'
      return $ CaseAlternative [ConstructorBinder (Qualified (Just mn) ctorName) (map VarBinder idents)] (Right (caseResult idents))
      where
      caseResult idents =
        App (prodConstructor (Literal . StringLiteral . showQualified runProperName $ Qualified (Just mn) ctorName))
          . Literal . ArrayLiteral
          $ zipWith toSpineFun (map (Var . Qualified Nothing) idents) tys

    -- Object-typed fields are expanded label by label into an @SRecord@ of
    -- @{ recLabel, recValue }@ entries; every other field defers to the
    -- class member named by 'C.toSpine'.
    toSpineFun :: Expr -> Type -> Expr
    toSpineFun i r | Just rec <- objectType r =
      lamNull . recordConstructor . Literal . ArrayLiteral .
          map (\(str,typ) -> Literal $ ObjectLiteral [ ("recLabel", Literal (StringLiteral str))
                                                     , ("recValue", toSpineFun (Accessor str i) typ)
                                                     ]
              )
          $ decomposeRec rec
    toSpineFun i _ = lamNull $ App (mkGenVar (Ident C.toSpine)) i
  mkSpineFunction (PositionedDeclaration _ _ d) = mkSpineFunction d
  mkSpineFunction _ = internalError "mkSpineFunction: expected DataDeclaration"

  -- Builds a lambda returning a @SigProd@ applied to the qualified type name
  -- and an array with one @{ sigConstructor, sigValues }@ object per data
  -- constructor.
  mkSignatureFunction :: Declaration -> [Type] -> Expr
  mkSignatureFunction (DataDeclaration _ name tyArgs args) classArgs = lamNull . mkSigProd $ map mkProdClause args
    where
    mkSigProd :: [Expr] -> Expr
    mkSigProd =
      App
        (App
          (Constructor (Qualified (Just dataGeneric) (ProperName "SigProd")))
          (Literal (StringLiteral (showQualified runProperName (Qualified (Just mn) name))))
        )
      . Literal
      . ArrayLiteral

    mkSigRec :: [Expr] -> Expr
    mkSigRec = App (Constructor (Qualified (Just dataGeneric) (ProperName "SigRecord"))) . Literal . ArrayLiteral

    proxy :: Type -> Type
    proxy = TypeApp (TypeConstructor (Qualified (Just typesProxy) (ProperName "Proxy")))

    mkProdClause :: (ProperName 'ConstructorName, [Type]) -> Expr
    mkProdClause (ctorName, tys) =
      Literal $ ObjectLiteral
        [ ("sigConstructor", Literal (StringLiteral (showQualified runProperName (Qualified (Just mn) ctorName))))
        , ("sigValues", Literal . ArrayLiteral . map (mkProductSignature . instantiate) $ tys)
        ]

    -- Object-typed fields become a @SigRecord@ of per-label signatures; any
    -- other field calls the class member named by 'C.toSignature' on
    -- @anyProxy@ annotated with @Proxy typ@.
    mkProductSignature :: Type -> Expr
    mkProductSignature r | Just rec <- objectType r =
        lamNull . mkSigRec $ [ Literal $ ObjectLiteral [ ("recLabel", Literal (StringLiteral str))
                                                       , ("recValue", mkProductSignature typ)
                                                       ]
                             | (str, typ) <- decomposeRec rec
                             ]
    mkProductSignature typ = lamNull $ App (mkGenVar (Ident C.toSignature))
                             (TypedValue False (mkGenVar (Ident "anyProxy")) (proxy typ))

    -- Replaces the data type's type variables with the instance head's
    -- arguments, pairing them positionally.
    instantiate = replaceAllTypeVars (zipWith (\(arg, _) ty -> (arg, ty)) tyArgs classArgs)
  mkSignatureFunction (PositionedDeclaration _ _ d) classArgs = mkSignatureFunction d classArgs
  mkSignatureFunction _ _ = internalError "mkSignatureFunction: expected DataDeclaration"

  -- Builds a lambda that cases on a spine: one alternative per data
  -- constructor matching @SProd "<qualified ctor>" [args...]@, followed by a
  -- wildcard alternative that yields @Nothing@.
  mkFromSpineFunction :: Declaration -> m Expr
  mkFromSpineFunction (DataDeclaration _ _ _ args) = do
    x <- freshIdent'
    lamCase x <$> (addCatch <$> mapM mkAlternative args)
    where
    mkJust :: Expr -> Expr
    mkJust = App (Constructor (Qualified (Just dataMaybe) (ProperName "Just")))

    mkNothing :: Expr
    mkNothing = Constructor (Qualified (Just dataMaybe) (ProperName "Nothing"))

    prodBinder :: [Binder] -> Binder
    prodBinder = ConstructorBinder (Qualified (Just dataGeneric) (ProperName "SProd"))

    recordBinder :: [Binder] -> Binder
    recordBinder = ConstructorBinder (Qualified (Just dataGeneric) (ProperName "SRecord"))

    mkAlternative :: (ProperName 'ConstructorName, [Type]) -> m CaseAlternative
    mkAlternative (ctorName, tys) = do
      idents <- replicateM (length tys) freshIdent'
      return $ CaseAlternative [ prodBinder [ LiteralBinder (StringLiteral (showQualified runProperName (Qualified (Just mn) ctorName)))
                                            , LiteralBinder (ArrayLiteral (map VarBinder idents))
                                            ]
                               ]
                 . Right
                 $ liftApplicative (mkJust $ Constructor (Qualified (Just mn) ctorName))
                                   (zipWith fromSpineFun (map (Var . Qualified Nothing) idents) tys)

    addCatch :: [CaseAlternative] -> [CaseAlternative]
    addCatch = (++ [catchAll])
      where
      catchAll = CaseAlternative [NullBinder] (Right mkNothing)

    -- The spine entry @e@ is first applied to @unit@. For object-typed fields
    -- the result is matched against @SRecord@; otherwise it is handed to the
    -- class member named by 'C.fromSpine'.
    fromSpineFun :: Expr -> Type -> Expr
    fromSpineFun e r
      | Just rec <- objectType r
      = App (lamCase (Ident "r") [ mkRecCase (decomposeRec rec)
                                 , CaseAlternative [NullBinder] (Right mkNothing)
                                 ])
            (App e (mkPrelVar (Ident "unit")))
    fromSpineFun e _ = App (mkGenVar (Ident C.fromSpine)) (App e (mkPrelVar (Ident "unit")))

    -- The array binder names each element after its record label, so the
    -- labels double as the variables read via @.recValue@.
    mkRecCase :: [(String, Type)] -> CaseAlternative
    mkRecCase rs =
      CaseAlternative [ recordBinder [ LiteralBinder (ArrayLiteral (map (VarBinder . Ident . fst) rs)) ] ]
        . Right
        $ liftApplicative (mkRecFun rs) (map (\(x, y) -> fromSpineFun (Accessor "recValue" (mkVar (Ident x))) y) rs)

    -- @Just@ applied to a curried function that takes one argument per label
    -- and assembles the object literal.
    mkRecFun :: [(String, Type)] -> Expr
    mkRecFun xs = mkJust $ foldr lam recLiteral (map (Ident . fst) xs)
      where recLiteral = Literal . ObjectLiteral $ map (\(s,_) -> (s, mkVar (Ident s))) xs
  mkFromSpineFunction (PositionedDeclaration _ _ d) = mkFromSpineFunction d
  mkFromSpineFunction _ = internalError "mkFromSpineFunction: expected DataDeclaration"

  -- Helpers

  -- Left-nested chain of Prelude @apply@ calls: @apply (apply f a) b ...@.
  liftApplicative :: Expr -> [Expr] -> Expr
  liftApplicative = foldl' (\x e -> App (App (mkPrelVar (Ident "apply")) x) e)

  mkPrelVar :: Ident -> Expr
  mkPrelVar = mkVarMn (Just (ModuleName [ProperName C.prelude]))

  mkGenVar :: Ident -> Expr
  mkGenVar = mkVarMn (Just (ModuleName [ProperName "Data", ProperName C.generic]))

-- | Generates the single member declaration (named by 'C.eq') of a derived
-- @Eq@ instance for the data type called @tyConNm@, which is looked up among
-- @ds@ with 'findTypeDecl'.
--
-- The generated function cases on both arguments at once, with one
-- alternative per data constructor that conjoins the field-wise equality
-- tests.
deriveEq
  :: forall m. (MonadError MultipleErrors m, MonadSupply m)
  => ModuleName
  -> [Declaration]
  -> ProperName 'TypeName
  -> m [Declaration]
deriveEq mn ds tyConNm = do
  tyCon <- findTypeDecl tyConNm ds
  eqFun <- mkEqFunction tyCon
  return [ ValueDeclaration (Ident C.eq) Public [] (Right eqFun) ]
  where
  mkEqFunction :: Declaration -> m Expr
  mkEqFunction (DataDeclaration _ _ _ args) = do
    x <- freshIdent "x"
    y <- freshIdent "y"
    lamCase2 x y <$> (addCatch <$> mapM mkCtorClause args)
  mkEqFunction (PositionedDeclaration _ _ d) = mkEqFunction d
  mkEqFunction _ = internalError "mkEqFunction: expected DataDeclaration"

  preludeConj :: Expr -> Expr -> Expr
  preludeConj = App . App (Var (Qualified (Just (ModuleName [ProperName C.prelude])) (Ident C.conj)))

  preludeEq :: Expr -> Expr -> Expr
  preludeEq = App . App (Var (Qualified (Just (ModuleName [ProperName C.prelude])) (Ident C.eq)))

  -- Appends a wildcard alternative yielding @false@, except when there is
  -- exactly one alternative.
  addCatch :: [CaseAlternative] -> [CaseAlternative]
  addCatch xs
    | length xs /= 1 = xs ++ [catchAll]
    | otherwise = xs -- Avoid redundant case
    where
    catchAll = CaseAlternative [NullBinder, NullBinder] (Right (Literal (BooleanLiteral False)))

  mkCtorClause :: (ProperName 'ConstructorName, [Type]) -> m CaseAlternative
  mkCtorClause (ctorName, tys) = do
    identsL <- replicateM (length tys) (freshIdent "l")
    identsR <- replicateM (length tys) (freshIdent "r")
    let tests = zipWith3 toEqTest (map (Var . Qualified Nothing) identsL) (map (Var . Qualified Nothing) identsR) tys
    return $ CaseAlternative [caseBinder identsL, caseBinder identsR] (Right (conjAll tests))
    where
    caseBinder idents = ConstructorBinder (Qualified (Just mn) ctorName) (map VarBinder idents)

  -- Conjunction of all tests; an empty list of tests is @true@.
  conjAll :: [Expr] -> Expr
  conjAll [] = Literal (BooleanLiteral True)
  conjAll xs = foldl1 preludeConj xs

  -- Object-typed fields are compared label by label; everything else is
  -- compared with the Prelude member named by 'C.eq'.
  toEqTest :: Expr -> Expr -> Type -> Expr
  toEqTest l r ty | Just rec <- objectType ty = conjAll . map (\(str, typ) -> toEqTest (Accessor str l) (Accessor str r) typ) $ decomposeRec rec
  toEqTest l r _ = preludeEq l r

-- | Generates the single member declaration (named by 'C.compare') of a
-- derived @Ord@ instance for the data type called @tyConNm@, which is looked
-- up among @ds@ with 'findTypeDecl'.
--
-- Constructors are ordered by their position in the data declaration: for
-- every constructor but the last, two extra alternatives make it compare as
-- @LT@ (respectively @GT@) against anything not matched by an earlier
-- alternative.
deriveOrd
  :: forall m. (MonadError MultipleErrors m, MonadSupply m)
  => ModuleName
  -> [Declaration]
  -> ProperName 'TypeName
  -> m [Declaration]
deriveOrd mn ds tyConNm = do
  tyCon <- findTypeDecl tyConNm ds
  compareFun <- mkCompareFunction tyCon
  return [ ValueDeclaration (Ident C.compare) Public [] (Right compareFun) ]
  where
  mkCompareFunction :: Declaration -> m Expr
  mkCompareFunction (DataDeclaration _ _ _ args) = do
    x <- freshIdent "x"
    y <- freshIdent "y"
    lamCase2 x y <$> (addCatch . concat <$> mapM mkCtorClauses (splitLast args))
  mkCompareFunction (PositionedDeclaration _ _ d) = mkCompareFunction d
  mkCompareFunction _ = internalError "mkCompareFunction: expected DataDeclaration"

  -- Tags every element with whether it is the last one of the list.
  splitLast :: [a] -> [(a, Bool)]
  splitLast [] = []
  splitLast [x] = [(x, True)]
  splitLast (x : xs) = (x, False) : splitLast xs

  addCatch :: [CaseAlternative] -> [CaseAlternative]
  addCatch xs
    | null xs = [catchAll] -- No type constructors
    | otherwise = xs
    where
    catchAll = CaseAlternative [NullBinder, NullBinder] (Right (preludeCtor "EQ"))

  preludeCtor :: String -> Expr
  preludeCtor = Constructor . Qualified (Just (ModuleName [ProperName C.prelude])) . ProperName

  preludeCompare :: Expr -> Expr -> Expr
  preludeCompare = App . App (Var (Qualified (Just (ModuleName [ProperName C.prelude])) (Ident C.compare)))

  mkCtorClauses :: ((ProperName 'ConstructorName, [Type]), Bool) -> m [CaseAlternative]
  mkCtorClauses ((ctorName, tys), isLast) = do
    identsL <- replicateM (length tys) (freshIdent "l")
    identsR <- replicateM (length tys) (freshIdent "r")
    let tests = zipWith3 toOrdering (map (Var . Qualified Nothing) identsL) (map (Var . Qualified Nothing) identsR) tys
        -- The last constructor needs no LT/GT fallbacks: every mixed pair
        -- involving it has already been matched by an earlier constructor's
        -- fallbacks.
        extras | not isLast = [ CaseAlternative [ ConstructorBinder (Qualified (Just mn) ctorName) (replicate (length tys) NullBinder)
                                                , NullBinder
                                                ]
                                                (Right (preludeCtor "LT"))
                              , CaseAlternative [ NullBinder
                                                , ConstructorBinder (Qualified (Just mn) ctorName) (replicate (length tys) NullBinder)
                                                ]
                                                (Right (preludeCtor "GT"))
                              ]
               | otherwise = []
    return $ CaseAlternative [ caseBinder identsL
                             , caseBinder identsR
                             ]
                             (Right (appendAll tests))
           : extras
    where
    caseBinder idents = ConstructorBinder (Qualified (Just mn) ctorName) (map VarBinder idents)

  -- Chains comparisons left to right: the first @LT@ or @GT@ decides, any
  -- other result moves on to the next comparison. No comparisons give @EQ@.
  appendAll :: [Expr] -> Expr
  appendAll [] = preludeCtor "EQ"
  appendAll [x] = x
  appendAll (x : xs) = Case [x] [ CaseAlternative [ ConstructorBinder (Qualified (Just (ModuleName [ProperName C.prelude])) (ProperName "LT")) [] ]
                                                  (Right (preludeCtor "LT"))
                                , CaseAlternative [ ConstructorBinder (Qualified (Just (ModuleName [ProperName C.prelude])) (ProperName "GT")) [] ]
                                                  (Right (preludeCtor "GT"))
                                , CaseAlternative [ NullBinder ] (Right (appendAll xs))
                                ]

  -- Object-typed fields are compared label by label (in the label order
  -- produced by 'decomposeRec'); everything else is compared with the Prelude
  -- member named by 'C.compare'.
  toOrdering :: Expr -> Expr -> Type -> Expr
  toOrdering l r ty | Just rec <- objectType ty = appendAll . map (\(str, typ) -> toOrdering (Accessor str l) (Accessor str r) typ) $ decomposeRec rec
  toOrdering l r _ = preludeCompare l r

-- | Finds the first data declaration named @tyConNm@ (looking through
-- 'PositionedDeclaration' wrappers) and returns it as found, wrapper
-- included. Fails with 'CannotFindDerivingType' when there is none.
findTypeDecl
  :: (MonadError MultipleErrors m)
  => ProperName 'TypeName
  -> [Declaration]
  -> m Declaration
findTypeDecl tyConNm = maybe (throwError . errorMessage $ CannotFindDerivingType tyConNm) return . find isTypeDecl
  where
  isTypeDecl :: Declaration -> Bool
  isTypeDecl (DataDeclaration _ nm _ _) | nm == tyConNm = True
  isTypeDecl (PositionedDeclaration _ _ d) = isTypeDecl d
  isTypeDecl _ = False

-- | A lambda abstraction over a single identifier.
lam :: Ident -> Expr -> Expr
lam = Abs . Left

-- | A lambda abstraction whose parameter is the fixed identifier @$q@.
lamNull :: Expr -> Expr
lamNull = lam (Ident "$q") -- TODO: use GenIdent

-- | A lambda that immediately cases on its own parameter.
lamCase :: Ident -> [CaseAlternative] -> Expr
lamCase s = lam s . Case [mkVar s]

-- | A two-argument (curried) lambda that cases on both parameters at once.
lamCase2 :: Ident -> Ident -> [CaseAlternative] -> Expr
lamCase2 s t = lam s . lam t . Case [mkVar s, mkVar t]

-- | A variable reference, optionally qualified with a module name.
mkVarMn :: Maybe ModuleName -> Ident -> Expr
mkVarMn mn = Var . Qualified mn

-- | An unqualified variable reference.
mkVar :: Ident -> Expr
mkVar = mkVarMn Nothing

-- | Returns the row argument when the type is @Prim.Object@ applied to a row,
-- and 'Nothing' for any other type.
objectType :: Type -> Maybe Type
objectType (TypeApp (TypeConstructor (Qualified (Just (ModuleName [ProperName "Prim"])) (ProperName "Object"))) rec) = Just rec
objectType _ = Nothing

-- | Flattens a row made of 'RCons' cells into @(label, type)@ pairs sorted by
-- label. The walk stops at the first node that is not an 'RCons'.
decomposeRec :: Type -> [(String, Type)]
decomposeRec = sortBy (comparing fst) . go
  where
  -- Recurse with 'go' rather than 'decomposeRec': the single stable sort
  -- above already orders the whole list, so re-sorting every suffix would
  -- only repeat work without changing the result.
  go (RCons str typ typs) = (str, typ) : go typs
  go _ = []