module Language.PureScript.Sugar.TypeClasses.Deriving (deriveInstances) where
import Prelude.Compat
import Control.Arrow (second)
import Control.Monad (replicateM, zipWithM)
import Control.Monad.Error.Class (MonadError(..))
import Control.Monad.Supply.Class (MonadSupply)
import Data.List (foldl', find, sortBy, unzip5)
import qualified Data.Map as M
import Data.Monoid ((<>))
import Data.Maybe (fromMaybe, mapMaybe)
import Data.Ord (comparing)
import Data.Text (Text)
import qualified Data.Text as T
import Language.PureScript.AST
import qualified Language.PureScript.Constants as C
import Language.PureScript.Crash
import Language.PureScript.Environment
import Language.PureScript.Errors
import Language.PureScript.Externs
import Language.PureScript.Kinds
import Language.PureScript.Names
import Language.PureScript.Label (Label(..))
import Language.PureScript.PSString (mkString, decodeStringEither)
import Language.PureScript.Types
import Language.PureScript.TypeChecker (checkNewtype)
import Language.PureScript.TypeChecker.Synonyms (SynonymMap, replaceAllTypeSynonymsM)
deriveInstances
:: forall m
. (MonadError MultipleErrors m, MonadSupply m)
=> [ExternsFile]
-> Module
-> m Module
deriveInstances externs (Module ss coms mn ds exts) =
Module ss coms mn <$> mapM (deriveInstance mn synonyms ds) ds <*> pure exts
where
synonyms :: SynonymMap
synonyms =
M.fromList $ (externs >>= \ExternsFile{..} -> mapMaybe (fromExternsDecl efModuleName) efDeclarations)
++ mapMaybe fromLocalDecl ds
where
fromExternsDecl mn' (EDTypeSynonym name args ty) = Just (Qualified (Just mn') name, (args, ty))
fromExternsDecl _ _ = Nothing
fromLocalDecl (TypeSynonymDeclaration name args ty) = do
Just (Qualified (Just mn) name, (args, ty))
fromLocalDecl (PositionedDeclaration _ _ d) = fromLocalDecl d
fromLocalDecl _ = Nothing
deriveInstance
:: (MonadError MultipleErrors m, MonadSupply m)
=> ModuleName
-> SynonymMap
-> [Declaration]
-> Declaration
-> m Declaration
deriveInstance mn syns ds (TypeInstanceDeclaration nm deps className tys DerivedInstance)
| className == Qualified (Just dataGeneric) (ProperName C.generic)
= case tys of
[ty] | Just (Qualified mn' tyCon, args) <- unwrapTypeConstructor ty
, mn == fromMaybe mn mn'
-> TypeInstanceDeclaration nm deps className tys . ExplicitInstance <$> deriveGeneric mn syns ds tyCon args
| otherwise -> throwError . errorMessage $ ExpectedTypeConstructor className tys ty
_ -> throwError . errorMessage $ InvalidDerivedInstance className tys 1
| className == Qualified (Just dataEq) (ProperName "Eq")
= case tys of
[ty] | Just (Qualified mn' tyCon, _) <- unwrapTypeConstructor ty
, mn == fromMaybe mn mn'
-> TypeInstanceDeclaration nm deps className tys . ExplicitInstance <$> deriveEq mn syns ds tyCon
| otherwise -> throwError . errorMessage $ ExpectedTypeConstructor className tys ty
_ -> throwError . errorMessage $ InvalidDerivedInstance className tys 1
| className == Qualified (Just dataOrd) (ProperName "Ord")
= case tys of
[ty] | Just (Qualified mn' tyCon, _) <- unwrapTypeConstructor ty
, mn == fromMaybe mn mn'
-> TypeInstanceDeclaration nm deps className tys . ExplicitInstance <$> deriveOrd mn syns ds tyCon
| otherwise -> throwError . errorMessage $ ExpectedTypeConstructor className tys ty
_ -> throwError . errorMessage $ InvalidDerivedInstance className tys 1
| className == Qualified (Just dataFunctor) (ProperName "Functor")
= case tys of
[ty] | Just (Qualified mn' tyCon, _) <- unwrapTypeConstructor ty
, mn == fromMaybe mn mn'
-> TypeInstanceDeclaration nm deps className tys . ExplicitInstance <$> deriveFunctor mn syns ds tyCon
| otherwise -> throwError . errorMessage $ ExpectedTypeConstructor className tys ty
_ -> throwError . errorMessage $ InvalidDerivedInstance className tys 1
| className == Qualified (Just dataNewtype) (ProperName "Newtype")
= case tys of
[wrappedTy, unwrappedTy]
| Just (Qualified mn' tyCon, args) <- unwrapTypeConstructor wrappedTy
, mn == fromMaybe mn mn'
-> do (inst, actualUnwrappedTy) <- deriveNewtype mn syns ds tyCon args unwrappedTy
return $ TypeInstanceDeclaration nm deps className [wrappedTy, actualUnwrappedTy] (ExplicitInstance inst)
| otherwise -> throwError . errorMessage $ ExpectedTypeConstructor className tys wrappedTy
_ -> throwError . errorMessage $ InvalidDerivedInstance className tys 2
| className == Qualified (Just dataGenericRep) (ProperName C.generic)
= case tys of
[actualTy, repTy]
| Just (Qualified mn' tyCon, args) <- unwrapTypeConstructor actualTy
, mn == fromMaybe mn mn'
-> do (inst, inferredRepTy) <- deriveGenericRep mn syns ds tyCon args repTy
return $ TypeInstanceDeclaration nm deps className [actualTy, inferredRepTy] (ExplicitInstance inst)
| otherwise -> throwError . errorMessage $ ExpectedTypeConstructor className tys actualTy
_ -> throwError . errorMessage $ InvalidDerivedInstance className tys 2
| otherwise = throwError . errorMessage $ CannotDerive className tys
deriveInstance mn syns ds (TypeInstanceDeclaration nm deps className tys NewtypeInstance) =
case tys of
_ : _ | Just (Qualified mn' tyCon, args) <- unwrapTypeConstructor (last tys)
, mn == fromMaybe mn mn'
-> TypeInstanceDeclaration nm deps className tys . NewtypeInstanceWithDictionary <$> deriveNewtypeInstance syns className ds tys tyCon args
| otherwise -> throwError . errorMessage $ ExpectedTypeConstructor className tys (last tys)
_ -> throwError . errorMessage $ InvalidNewtypeInstance className tys
deriveInstance mn syns ds (PositionedDeclaration pos com d) = PositionedDeclaration pos com <$> deriveInstance mn syns ds d
deriveInstance _ _ _ e = return e
unwrapTypeConstructor :: Type -> Maybe (Qualified (ProperName 'TypeName), [Type])
unwrapTypeConstructor = fmap (second reverse) . go
where
go (TypeConstructor tyCon) = Just (tyCon, [])
go (TypeApp ty arg) = do
(tyCon, args) <- go ty
return (tyCon, arg : args)
go _ = Nothing
deriveNewtypeInstance
:: forall m
. MonadError MultipleErrors m
=> SynonymMap
-> Qualified (ProperName 'ClassName)
-> [Declaration]
-> [Type]
-> ProperName 'TypeName
-> [Type]
-> m Expr
deriveNewtypeInstance syns className ds tys tyConNm dargs = do
tyCon <- findTypeDecl tyConNm ds
go tyCon
where
go (DataDeclaration Newtype _ tyArgNames [(_, [wrapped])])
| Just wrapped' <- stripRight (takeReverse (length tyArgNames length dargs) tyArgNames) wrapped =
do let subst = zipWith (\(name, _) t -> (name, t)) tyArgNames dargs
wrapped'' <- replaceAllTypeSynonymsM syns wrapped'
return (DeferredDictionary className (init tys ++ [replaceAllTypeVars subst wrapped'']))
go (PositionedDeclaration _ _ d) = go d
go _ = throwError . errorMessage $ InvalidNewtypeInstance className tys
takeReverse :: Int -> [a] -> [a]
takeReverse n = take n . reverse
stripRight :: [(Text, Maybe kind)] -> Type -> Maybe Type
stripRight [] ty = Just ty
stripRight ((arg, _) : args) (TypeApp t (TypeVar arg'))
| arg == arg' = stripRight args t
stripRight _ _ = Nothing
dataGeneric :: ModuleName
dataGeneric = ModuleName [ ProperName "Data", ProperName "Generic" ]
dataGenericRep :: ModuleName
dataGenericRep = ModuleName [ ProperName "Data", ProperName "Generic", ProperName "Rep" ]
dataMaybe :: ModuleName
dataMaybe = ModuleName [ ProperName "Data", ProperName "Maybe" ]
typesProxy :: ModuleName
typesProxy = ModuleName [ ProperName "Type", ProperName "Proxy" ]
dataEq :: ModuleName
dataEq = ModuleName [ ProperName "Data", ProperName "Eq" ]
dataOrd :: ModuleName
dataOrd = ModuleName [ ProperName "Data", ProperName "Ord" ]
dataNewtype :: ModuleName
dataNewtype = ModuleName [ ProperName "Data", ProperName "Newtype" ]
dataFunctor :: ModuleName
dataFunctor = ModuleName [ ProperName "Data", ProperName "Functor" ]
unguarded :: Expr -> [GuardedExpr]
unguarded e = [MkUnguarded e]
deriveGeneric
:: forall m. (MonadError MultipleErrors m, MonadSupply m)
=> ModuleName
-> SynonymMap
-> [Declaration]
-> ProperName 'TypeName
-> [Type]
-> m [Declaration]
deriveGeneric mn syns ds tyConNm dargs = do
tyCon <- findTypeDecl tyConNm ds
toSpine <- mkSpineFunction tyCon
fromSpine <- mkFromSpineFunction tyCon
toSignature <- mkSignatureFunction tyCon dargs
return [ ValueDeclaration (Ident C.toSpine) Public [] (unguarded toSpine)
, ValueDeclaration (Ident C.fromSpine) Public [] (unguarded fromSpine)
, ValueDeclaration (Ident C.toSignature) Public [] (unguarded toSignature)
]
where
mkSpineFunction :: Declaration -> m Expr
mkSpineFunction (DataDeclaration _ _ _ args) = do
x <- freshIdent'
lamCase x <$> mapM mkCtorClause args
where
prodConstructor :: Expr -> Expr
prodConstructor = App (Constructor (Qualified (Just dataGeneric) (ProperName "SProd")))
recordConstructor :: Expr -> Expr
recordConstructor = App (Constructor (Qualified (Just dataGeneric) (ProperName "SRecord")))
mkCtorClause :: (ProperName 'ConstructorName, [Type]) -> m CaseAlternative
mkCtorClause (ctorName, tys) = do
idents <- replicateM (length tys) freshIdent'
tys' <- mapM (replaceAllTypeSynonymsM syns) tys
let caseResult =
App (prodConstructor (Literal . StringLiteral . mkString . showQualified runProperName $ Qualified (Just mn) ctorName))
. Literal . ArrayLiteral
$ zipWith toSpineFun (map (Var . Qualified Nothing) idents) tys'
return $ CaseAlternative [ConstructorBinder (Qualified (Just mn) ctorName) (map VarBinder idents)] (unguarded caseResult)
toSpineFun :: Expr -> Type -> Expr
toSpineFun i r | Just rec <- objectType r
, Just fields <- decomposeRec rec =
lamNull . recordConstructor . Literal . ArrayLiteral
. map
(\((Label str),typ) ->
Literal $ ObjectLiteral
[ ("recLabel", Literal (StringLiteral str))
, ("recValue", toSpineFun (Accessor str i) typ)
]
)
$ fields
toSpineFun i _ = lamNull $ App (mkGenVar (Ident C.toSpine)) i
mkSpineFunction (PositionedDeclaration _ _ d) = mkSpineFunction d
mkSpineFunction _ = internalError "mkSpineFunction: expected DataDeclaration"
mkSignatureFunction :: Declaration -> [Type] -> m Expr
mkSignatureFunction (DataDeclaration _ name tyArgs args) classArgs = lamNull . mkSigProd <$> mapM mkProdClause args
where
mkSigProd :: [Expr] -> Expr
mkSigProd =
App
(App
(Constructor (Qualified (Just dataGeneric) (ProperName "SigProd")))
(Literal (StringLiteral $ mkString (showQualified runProperName (Qualified (Just mn) name))))
)
. Literal
. ArrayLiteral
mkSigRec :: [Expr] -> Expr
mkSigRec = App (Constructor (Qualified (Just dataGeneric) (ProperName "SigRecord"))) . Literal . ArrayLiteral
proxy :: Type -> Type
proxy = TypeApp (TypeConstructor (Qualified (Just typesProxy) (ProperName "Proxy")))
mkProdClause :: (ProperName 'ConstructorName, [Type]) -> m Expr
mkProdClause (ctorName, tys) = do
tys' <- mapM (replaceAllTypeSynonymsM syns) tys
return $ Literal $ ObjectLiteral
[ ("sigConstructor", Literal (StringLiteral $ mkString (showQualified runProperName (Qualified (Just mn) ctorName))))
, ("sigValues", Literal . ArrayLiteral . map (mkProductSignature . instantiate) $ tys')
]
mkProductSignature :: Type -> Expr
mkProductSignature r | Just rec <- objectType r
, Just fields <- decomposeRec rec =
lamNull . mkSigRec $
[ Literal $ ObjectLiteral
[ ("recLabel", Literal (StringLiteral str))
, ("recValue", mkProductSignature typ)
]
| ((Label str), typ) <- fields
]
mkProductSignature typ = lamNull $ App (mkGenVar (Ident C.toSignature))
(TypedValue False (mkGenVar (Ident "anyProxy")) (proxy typ))
instantiate = replaceAllTypeVars (zipWith (\(arg, _) ty -> (arg, ty)) tyArgs classArgs)
mkSignatureFunction (PositionedDeclaration _ _ d) classArgs = mkSignatureFunction d classArgs
mkSignatureFunction _ _ = internalError "mkSignatureFunction: expected DataDeclaration"
mkFromSpineFunction :: Declaration -> m Expr
mkFromSpineFunction (DataDeclaration _ _ _ args) = do
x <- freshIdent'
lamCase x <$> (addCatch <$> mapM mkAlternative args)
where
mkJust :: Expr -> Expr
mkJust = App (Constructor (Qualified (Just dataMaybe) (ProperName "Just")))
mkNothing :: Expr
mkNothing = Constructor (Qualified (Just dataMaybe) (ProperName "Nothing"))
prodBinder :: [Binder] -> Binder
prodBinder = ConstructorBinder (Qualified (Just dataGeneric) (ProperName "SProd"))
recordBinder :: [Binder] -> Binder
recordBinder = ConstructorBinder (Qualified (Just dataGeneric) (ProperName "SRecord"))
mkAlternative :: (ProperName 'ConstructorName, [Type]) -> m CaseAlternative
mkAlternative (ctorName, tys) = do
idents <- replicateM (length tys) freshIdent'
tys' <- mapM (replaceAllTypeSynonymsM syns) tys
return $
CaseAlternative
[ prodBinder
[ LiteralBinder (StringLiteral $ mkString (showQualified runProperName (Qualified (Just mn) ctorName)))
, LiteralBinder (ArrayLiteral (map VarBinder idents))
]
]
. unguarded
$ liftApplicative
(mkJust $ Constructor (Qualified (Just mn) ctorName))
(zipWith fromSpineFun (map (Var . Qualified Nothing) idents) tys')
addCatch :: [CaseAlternative] -> [CaseAlternative]
addCatch = (++ [catchAll])
where
catchAll = CaseAlternative [NullBinder] (unguarded mkNothing)
fromSpineFun :: Expr -> Type -> Expr
fromSpineFun e r
| Just rec <- objectType r
, Just fields <- decomposeRec rec
= App (lamCase (Ident "r") [ mkRecCase fields
, CaseAlternative [NullBinder] (unguarded mkNothing)
])
(App e unitVal)
fromSpineFun e _ = App (mkGenVar (Ident C.fromSpine)) (App e unitVal)
mkRecCase :: [(Label, Type)] -> CaseAlternative
mkRecCase rs =
CaseAlternative
[ recordBinder [ LiteralBinder (ArrayLiteral (map (VarBinder . labelToIdent . fst) rs)) ] ]
. unguarded
$ liftApplicative (mkRecFun rs) (map (\(x, y) -> fromSpineFun (Accessor "recValue" (mkVar $ labelToIdent x)) y) rs)
mkRecFun :: [(Label, Type)] -> Expr
mkRecFun xs = mkJust $ foldr (lam . labelToIdent . fst) recLiteral xs
where recLiteral = Literal . ObjectLiteral $ map (\(l@(Label s), _) -> (s, mkVar $ labelToIdent l)) xs
mkFromSpineFunction (PositionedDeclaration _ _ d) = mkFromSpineFunction d
mkFromSpineFunction _ = internalError "mkFromSpineFunction: expected DataDeclaration"
liftApplicative :: Expr -> [Expr] -> Expr
liftApplicative = foldl' (\x e -> App (App applyFn x) e)
unitVal :: Expr
unitVal = mkVarMn (Just (ModuleName [ProperName "Data", ProperName "Unit"])) (Ident "unit")
applyFn :: Expr
applyFn = mkVarMn (Just (ModuleName [ProperName "Control", ProperName "Apply"])) (Ident "apply")
mkGenVar :: Ident -> Expr
mkGenVar = mkVarMn (Just (ModuleName [ProperName "Data", ProperName C.generic]))
deriveGenericRep
:: forall m
. (MonadError MultipleErrors m, MonadSupply m)
=> ModuleName
-> SynonymMap
-> [Declaration]
-> ProperName 'TypeName
-> [Type]
-> Type
-> m ([Declaration], Type)
deriveGenericRep mn syns ds tyConNm tyConArgs repTy = do
checkIsWildcard tyConNm repTy
go =<< findTypeDecl tyConNm ds
where
go :: Declaration -> m ([Declaration], Type)
go (DataDeclaration _ _ args dctors) = do
x <- freshIdent "x"
(reps, to, from) <- unzip3 <$> traverse makeInst dctors
let rep = toRepTy reps
inst | null reps =
[ ValueDeclaration (Ident "to") Public [] $ unguarded $
lamCase x [ CaseAlternative [NullBinder]
(unguarded (App toName (Var (Qualified Nothing x))))
]
, ValueDeclaration (Ident "from") Public [] $ unguarded $
lamCase x [ CaseAlternative [NullBinder]
(unguarded (App fromName (Var (Qualified Nothing x))))
]
]
| otherwise =
[ ValueDeclaration (Ident "to") Public [] $ unguarded $
lamCase x (zipWith ($) (map underBinder (sumBinders (length dctors))) to)
, ValueDeclaration (Ident "from") Public [] $ unguarded $
lamCase x (zipWith ($) (map underExpr (sumExprs (length dctors))) from)
]
subst = zipWith ((,) . fst) args tyConArgs
return (inst, replaceAllTypeVars subst rep)
go (PositionedDeclaration _ _ d) = go d
go _ = internalError "deriveGenericRep go: expected DataDeclaration"
select :: (a -> a) -> (a -> a) -> Int -> [a -> a]
select _ _ 0 = []
select _ _ 1 = [id]
select l r n = take (n 1) (iterate (r .) l) ++ [compN (n 1) r]
sumBinders :: Int -> [Binder -> Binder]
sumBinders = select (ConstructorBinder inl . pure) (ConstructorBinder inr . pure)
sumExprs :: Int -> [Expr -> Expr]
sumExprs = select (App (Constructor inl)) (App (Constructor inr))
compN :: Int -> (a -> a) -> a -> a
compN 0 _ = id
compN n f = f . compN (n 1) f
makeInst
:: (ProperName 'ConstructorName, [Type])
-> m (Type, CaseAlternative, CaseAlternative)
makeInst (ctorName, args) = do
args' <- mapM (replaceAllTypeSynonymsM syns) args
(ctorTy, matchProduct, ctorArgs, matchCtor, mkProduct) <- makeProduct args'
return ( TypeApp (TypeApp (TypeConstructor constructor)
(TypeLevelString $ mkString (runProperName ctorName)))
ctorTy
, CaseAlternative [ ConstructorBinder constructor [matchProduct] ]
(unguarded (foldl App (Constructor (Qualified (Just mn) ctorName)) ctorArgs))
, CaseAlternative [ ConstructorBinder (Qualified (Just mn) ctorName) matchCtor ]
(unguarded (constructor' mkProduct))
)
makeProduct
:: [Type]
-> m (Type, Binder, [Expr], [Binder], Expr)
makeProduct [] =
pure (noArgs, NullBinder, [], [], noArgs')
makeProduct args = do
(tys, bs1, es1, bs2, es2) <- unzip5 <$> traverse makeArg args
pure ( foldr1 (\f -> TypeApp (TypeApp (TypeConstructor productName) f)) tys
, foldr1 (\b1 b2 -> ConstructorBinder productName [b1, b2]) bs1
, es1
, bs2
, foldr1 (\e1 -> App (App (Constructor productName) e1)) es2
)
makeArg :: Type -> m (Type, Binder, Expr, Binder, Expr)
makeArg arg | Just rec <- objectType arg
, Just fields <- decomposeRec rec = do
fieldNames <- traverse freshIdent (map (runIdent . labelToIdent . fst) fields)
case fieldNames of
[] ->
pure ( TypeApp (TypeConstructor record) noArgs
, ConstructorBinder record [ NullBinder ]
, Literal (ObjectLiteral [])
, LiteralBinder (ObjectLiteral [])
, record' noArgs'
)
_ ->
pure ( TypeApp (TypeConstructor record)
(foldr1 (\f -> TypeApp (TypeApp (TypeConstructor productName) f))
(map (\((Label name), ty) ->
TypeApp (TypeApp (TypeConstructor field) (TypeLevelString name)) ty) fields))
, ConstructorBinder record
[ foldr1 (\b1 b2 -> ConstructorBinder productName [b1, b2])
(map (\ident -> ConstructorBinder field [VarBinder ident]) fieldNames)
]
, Literal . ObjectLiteral $
zipWith (\((Label name), _) ident -> (name, Var (Qualified Nothing ident))) fields fieldNames
, LiteralBinder . ObjectLiteral $
zipWith (\((Label name), _) ident -> (name, VarBinder ident)) fields fieldNames
, record' $
foldr1 (\e1 -> App (App (Constructor productName) e1))
(map (field' . Var . Qualified Nothing) fieldNames)
)
makeArg arg = do
argName <- freshIdent "arg"
pure ( TypeApp (TypeConstructor argument) arg
, ConstructorBinder argument [ VarBinder argName ]
, Var (Qualified Nothing argName)
, VarBinder argName
, argument' (Var (Qualified Nothing argName))
)
underBinder :: (Binder -> Binder) -> CaseAlternative -> CaseAlternative
underBinder f (CaseAlternative bs e) = CaseAlternative (map f bs) e
underExpr :: (Expr -> Expr) -> CaseAlternative -> CaseAlternative
underExpr f (CaseAlternative b [MkUnguarded e]) = CaseAlternative b (unguarded (f e))
underExpr _ _ = internalError "underExpr: expected unguarded alternative"
toRepTy :: [Type] -> Type
toRepTy [] = noCtors
toRepTy [only] = only
toRepTy ctors = foldr1 (\f -> TypeApp (TypeApp sumCtor f)) ctors
toName :: Expr
toName = Var (Qualified (Just dataGenericRep) (Ident "to"))
fromName :: Expr
fromName = Var (Qualified (Just dataGenericRep) (Ident "from"))
noCtors :: Type
noCtors = TypeConstructor (Qualified (Just dataGenericRep) (ProperName "NoConstructors"))
noArgs :: Type
noArgs = TypeConstructor (Qualified (Just dataGenericRep) (ProperName "NoArguments"))
noArgs' :: Expr
noArgs' = Constructor (Qualified (Just dataGenericRep) (ProperName "NoArguments"))
sumCtor :: Type
sumCtor = TypeConstructor (Qualified (Just dataGenericRep) (ProperName "Sum"))
inl :: Qualified (ProperName 'ConstructorName)
inl = Qualified (Just dataGenericRep) (ProperName "Inl")
inr :: Qualified (ProperName 'ConstructorName)
inr = Qualified (Just dataGenericRep) (ProperName "Inr")
productName :: Qualified (ProperName ty)
productName = Qualified (Just dataGenericRep) (ProperName "Product")
constructor :: Qualified (ProperName ty)
constructor = Qualified (Just dataGenericRep) (ProperName "Constructor")
constructor' :: Expr -> Expr
constructor' = App (Constructor constructor)
argument :: Qualified (ProperName ty)
argument = Qualified (Just dataGenericRep) (ProperName "Argument")
argument' :: Expr -> Expr
argument' = App (Constructor argument)
record :: Qualified (ProperName ty)
record = Qualified (Just dataGenericRep) (ProperName "Rec")
record' :: Expr -> Expr
record' = App (Constructor record)
field :: Qualified (ProperName ty)
field = Qualified (Just dataGenericRep) (ProperName "Field")
field' :: Expr -> Expr
field' = App (Constructor field)
checkIsWildcard :: MonadError MultipleErrors m => ProperName 'TypeName -> Type -> m ()
checkIsWildcard _ (TypeWildcard _) = return ()
checkIsWildcard tyConNm _ =
throwError . errorMessage $ ExpectedWildcard tyConNm
deriveEq ::
forall m. (MonadError MultipleErrors m, MonadSupply m)
=> ModuleName
-> SynonymMap
-> [Declaration]
-> ProperName 'TypeName
-> m [Declaration]
deriveEq mn syns ds tyConNm = do
tyCon <- findTypeDecl tyConNm ds
eqFun <- mkEqFunction tyCon
return [ ValueDeclaration (Ident C.eq) Public [] (unguarded eqFun) ]
where
mkEqFunction :: Declaration -> m Expr
mkEqFunction (DataDeclaration _ _ _ args) = do
x <- freshIdent "x"
y <- freshIdent "y"
lamCase2 x y <$> (addCatch <$> mapM mkCtorClause args)
mkEqFunction (PositionedDeclaration _ _ d) = mkEqFunction d
mkEqFunction _ = internalError "mkEqFunction: expected DataDeclaration"
preludeConj :: Expr -> Expr -> Expr
preludeConj = App . App (Var (Qualified (Just (ModuleName [ProperName "Data", ProperName "HeytingAlgebra"])) (Ident C.conj)))
preludeEq :: Expr -> Expr -> Expr
preludeEq = App . App (Var (Qualified (Just dataEq) (Ident C.eq)))
addCatch :: [CaseAlternative] -> [CaseAlternative]
addCatch xs
| length xs /= 1 = xs ++ [catchAll]
| otherwise = xs
where
catchAll = CaseAlternative [NullBinder, NullBinder] (unguarded (Literal (BooleanLiteral False)))
mkCtorClause :: (ProperName 'ConstructorName, [Type]) -> m CaseAlternative
mkCtorClause (ctorName, tys) = do
identsL <- replicateM (length tys) (freshIdent "l")
identsR <- replicateM (length tys) (freshIdent "r")
tys' <- mapM (replaceAllTypeSynonymsM syns) tys
let tests = zipWith3 toEqTest (map (Var . Qualified Nothing) identsL) (map (Var . Qualified Nothing) identsR) tys'
return $ CaseAlternative [caseBinder identsL, caseBinder identsR] (unguarded (conjAll tests))
where
caseBinder idents = ConstructorBinder (Qualified (Just mn) ctorName) (map VarBinder idents)
conjAll :: [Expr] -> Expr
conjAll [] = Literal (BooleanLiteral True)
conjAll xs = foldl1 preludeConj xs
toEqTest :: Expr -> Expr -> Type -> Expr
toEqTest l r ty | Just rec <- objectType ty
, Just fields <- decomposeRec rec =
conjAll
. map (\((Label str), typ) -> toEqTest (Accessor str l) (Accessor str r) typ)
$ fields
toEqTest l r _ = preludeEq l r
deriveOrd ::
forall m. (MonadError MultipleErrors m, MonadSupply m)
=> ModuleName
-> SynonymMap
-> [Declaration]
-> ProperName 'TypeName
-> m [Declaration]
deriveOrd mn syns ds tyConNm = do
tyCon <- findTypeDecl tyConNm ds
compareFun <- mkCompareFunction tyCon
return [ ValueDeclaration (Ident C.compare) Public [] (unguarded compareFun) ]
where
mkCompareFunction :: Declaration -> m Expr
mkCompareFunction (DataDeclaration _ _ _ args) = do
x <- freshIdent "x"
y <- freshIdent "y"
lamCase2 x y <$> (addCatch . concat <$> mapM mkCtorClauses (splitLast args))
mkCompareFunction (PositionedDeclaration _ _ d) = mkCompareFunction d
mkCompareFunction _ = internalError "mkCompareFunction: expected DataDeclaration"
splitLast :: [a] -> [(a, Bool)]
splitLast [] = []
splitLast [x] = [(x, True)]
splitLast (x : xs) = (x, False) : splitLast xs
addCatch :: [CaseAlternative] -> [CaseAlternative]
addCatch xs
| null xs = [catchAll]
| otherwise = xs
where
catchAll = CaseAlternative [NullBinder, NullBinder] (unguarded (orderingCtor "EQ"))
orderingName :: Text -> Qualified (ProperName a)
orderingName = Qualified (Just (ModuleName [ProperName "Data", ProperName "Ordering"])) . ProperName
orderingCtor :: Text -> Expr
orderingCtor = Constructor . orderingName
orderingBinder :: Text -> Binder
orderingBinder name = ConstructorBinder (orderingName name) []
ordCompare :: Expr -> Expr -> Expr
ordCompare = App . App (Var (Qualified (Just dataOrd) (Ident C.compare)))
mkCtorClauses :: ((ProperName 'ConstructorName, [Type]), Bool) -> m [CaseAlternative]
mkCtorClauses ((ctorName, tys), isLast) = do
identsL <- replicateM (length tys) (freshIdent "l")
identsR <- replicateM (length tys) (freshIdent "r")
tys' <- mapM (replaceAllTypeSynonymsM syns) tys
let tests = zipWith3 toOrdering (map (Var . Qualified Nothing) identsL) (map (Var . Qualified Nothing) identsR) tys'
extras | not isLast = [ CaseAlternative [ ConstructorBinder (Qualified (Just mn) ctorName) (replicate (length tys) NullBinder)
, NullBinder
]
(unguarded (orderingCtor "LT"))
, CaseAlternative [ NullBinder
, ConstructorBinder (Qualified (Just mn) ctorName) (replicate (length tys) NullBinder)
]
(unguarded (orderingCtor "GT"))
]
| otherwise = []
return $ CaseAlternative [ caseBinder identsL
, caseBinder identsR
]
(unguarded (appendAll tests))
: extras
where
caseBinder idents = ConstructorBinder (Qualified (Just mn) ctorName) (map VarBinder idents)
appendAll :: [Expr] -> Expr
appendAll [] = orderingCtor "EQ"
appendAll [x] = x
appendAll (x : xs) = Case [x] [ CaseAlternative [orderingBinder "LT"]
(unguarded (orderingCtor "LT"))
, CaseAlternative [orderingBinder "GT"]
(unguarded (orderingCtor "GT"))
, CaseAlternative [ NullBinder ]
(unguarded (appendAll xs))
]
toOrdering :: Expr -> Expr -> Type -> Expr
toOrdering l r ty | Just rec <- objectType ty
, Just fields <- decomposeRec rec =
appendAll
. map (\((Label str), typ) -> toOrdering (Accessor str l) (Accessor str r) typ)
$ fields
toOrdering l r _ = ordCompare l r
deriveNewtype
:: forall m
. (MonadError MultipleErrors m, MonadSupply m)
=> ModuleName
-> SynonymMap
-> [Declaration]
-> ProperName 'TypeName
-> [Type]
-> Type
-> m ([Declaration], Type)
deriveNewtype mn syns ds tyConNm tyConArgs unwrappedTy = do
checkIsWildcard tyConNm unwrappedTy
go =<< findTypeDecl tyConNm ds
where
go :: Declaration -> m ([Declaration], Type)
go (DataDeclaration Data name _ _) =
throwError . errorMessage $ CannotDeriveNewtypeForData name
go (DataDeclaration Newtype name args dctors) = do
checkNewtype name dctors
wrappedIdent <- freshIdent "n"
unwrappedIdent <- freshIdent "a"
let (ctorName, [ty]) = head dctors
ty' <- replaceAllTypeSynonymsM syns ty
let inst =
[ ValueDeclaration (Ident "wrap") Public [] $ unguarded $
Constructor (Qualified (Just mn) ctorName)
, ValueDeclaration (Ident "unwrap") Public [] $ unguarded $
lamCase wrappedIdent
[ CaseAlternative
[ConstructorBinder (Qualified (Just mn) ctorName) [VarBinder unwrappedIdent]]
(unguarded (Var (Qualified Nothing unwrappedIdent)))
]
]
subst = zipWith ((,) . fst) args tyConArgs
return (inst, replaceAllTypeVars subst ty')
go (PositionedDeclaration _ _ d) = go d
go _ = internalError "deriveNewtype go: expected DataDeclaration"
findTypeDecl
:: (MonadError MultipleErrors m)
=> ProperName 'TypeName
-> [Declaration]
-> m Declaration
findTypeDecl tyConNm = maybe (throwError . errorMessage $ CannotFindDerivingType tyConNm) return . find isTypeDecl
where
isTypeDecl :: Declaration -> Bool
isTypeDecl (DataDeclaration _ nm _ _) | nm == tyConNm = True
isTypeDecl (PositionedDeclaration _ _ d) = isTypeDecl d
isTypeDecl _ = False
lam :: Ident -> Expr -> Expr
lam = Abs . VarBinder
lamNull :: Expr -> Expr
lamNull = lam (Ident "$q")
lamCase :: Ident -> [CaseAlternative] -> Expr
lamCase s = lam s . Case [mkVar s]
lamCase2 :: Ident -> Ident -> [CaseAlternative] -> Expr
lamCase2 s t = lam s . lam t . Case [mkVar s, mkVar t]
mkVarMn :: Maybe ModuleName -> Ident -> Expr
mkVarMn mn = Var . Qualified mn
mkVar :: Ident -> Expr
mkVar = mkVarMn Nothing
labelToIdent :: Label -> Ident
labelToIdent =
Ident . foldMap (either loneSurrogate char) . decodeStringEither . runLabel
where
char '_' = "__"
char c = T.singleton c
loneSurrogate x = "_" <> T.pack (show x) <> "_"
objectType :: Type -> Maybe Type
objectType (TypeApp (TypeConstructor (Qualified (Just (ModuleName [ProperName "Prim"])) (ProperName "Record"))) rec) = Just rec
objectType _ = Nothing
decomposeRec :: Type -> Maybe [(Label, Type)]
decomposeRec = fmap (sortBy (comparing fst)) . go
where go (RCons str typ typs) = fmap ((str, typ) :) (go typs)
go REmpty = Just []
go _ = Nothing
decomposeRec' :: Type -> [(Label, Type)]
decomposeRec' = sortBy (comparing fst) . go
where go (RCons str typ typs) = (str, typ) : go typs
go _ = []
deriveFunctor
:: forall m
. (MonadError MultipleErrors m, MonadSupply m)
=> ModuleName
-> SynonymMap
-> [Declaration]
-> ProperName 'TypeName
-> m [Declaration]
deriveFunctor mn syns ds tyConNm = do
tyCon <- findTypeDecl tyConNm ds
mapFun <- mkMapFunction tyCon
return [ ValueDeclaration (Ident C.map) Public [] (unguarded mapFun) ]
where
mkMapFunction :: Declaration -> m Expr
mkMapFunction (DataDeclaration _ _ tys ctors) = case reverse tys of
[] -> throwError . errorMessage $ KindsDoNotUnify (FunKind kindType kindType) kindType
((iTy, _) : _) -> do
f <- freshIdent "f"
m <- freshIdent "m"
lam f . lamCase m <$> mapM (mkCtorClause iTy f) ctors
mkMapFunction (PositionedDeclaration _ _ d) = mkMapFunction d
mkMapFunction _ = internalError "mkMapFunction: expected DataDeclaration"
mkCtorClause :: Text -> Ident -> (ProperName 'ConstructorName, [Type]) -> m CaseAlternative
mkCtorClause iTyName f (ctorName, ctorTys) = do
idents <- replicateM (length ctorTys) (freshIdent "v")
ctorTys' <- mapM (replaceAllTypeSynonymsM syns) ctorTys
args <- zipWithM transformArg idents ctorTys'
let ctor = Constructor (Qualified (Just mn) ctorName)
rebuilt = foldl App ctor args
caseBinder = ConstructorBinder (Qualified (Just mn) ctorName) (VarBinder <$> idents)
return $ CaseAlternative [caseBinder] (unguarded rebuilt)
where
fVar = mkVar f
mapVar = mkVarMn (Just dataFunctor) (Ident C.map)
transformArg :: Ident -> Type -> m Expr
transformArg ident = fmap (foldr App (mkVar ident)) . goType where
goType :: Type -> m (Maybe Expr)
goType (TypeVar t) | t == iTyName = return (Just fVar)
goType recTy | Just row <- objectType recTy =
traverse buildUpdate (decomposeRec' row) >>= (traverse buildRecord . justUpdates)
where
justUpdates :: [Maybe (Label, Expr)] -> Maybe [(Label, Expr)]
justUpdates = foldMap (fmap return)
buildUpdate :: (Label, Type) -> m (Maybe (Label, Expr))
buildUpdate (lbl, ty) = do upd <- goType ty
return ((lbl,) <$> upd)
buildRecord :: [(Label, Expr)] -> m Expr
buildRecord updates = do arg <- freshIdent "o"
let argVar = mkVar arg
mkAssignment ((Label l), x) = (l, App x (Accessor l argVar))
return (lam arg (ObjectUpdate argVar (mkAssignment <$> updates)))
goType (TypeApp _ t) = fmap (App mapVar) <$> goType t
goType _ = return Nothing