-- | Provides all kinds of functions that are needed by the plugin. module Control.Super.Plugin.Utils ( errIndent -- * Type inspection , collectTopTyCons , collectTopTcVars , collectTopTcVarsWithArity , collectTyVars , mkTcVarSubst -- * General Utilities , skolemVarsBindFun , eqTyVar, eqTyVar' , getTyConName, getClassName , isAmbiguousType , partiallyApplyTyCons , applyTyCon , splitKindFunOfTcTv , atIndex , t1st, t2nd, t3rd , associations --, subsets , removeDup, removeDupByIndex , removeDupByIndexEq , removeDupUnique, removeDupByIndexUnique , lookupBy , allM, anyM , fromLeft, fromRight , partitionM ) where import Data.Maybe ( listToMaybe, catMaybes ) import Data.List ( find ) import qualified Data.Set as Set import qualified Data.Map.Strict as Map import Control.Monad ( forM ) import Control.Arrow ( second ) import Unique ( Uniquable ) import BasicTypes ( Arity ) import Name ( nameOccName ) import OccName ( occNameString ) import Type ( Type, TyVar , getTyVar_maybe , tyConAppTyCon_maybe , splitTyConApp_maybe, splitFunTy_maybe, splitAppTy_maybe , getEqPredTys_maybe , splitAppTys , mkTyConTy, mkTyVarTy, mkAppTys , eqType ) import TyCon ( TyCon , tyConKind, tyConName ) import Var ( tyVarKind ) import TcType ( isAmbiguousTyVar ) import Kind ( Kind ) import Class ( Class, classTyCon ) import Unify ( BindFlag(..) ) import InstEnv ( instanceBindFun ) import TcPluginM ( TcPluginM, newFlexiTyVar ) import Outputable ( ($$) ) import qualified Outputable as O import Control.Super.Plugin.Collection.Set ( Set ) import qualified Control.Super.Plugin.Collection.Set as S import qualified Control.Super.Plugin.Collection.Map as M import Control.Super.Plugin.Wrapper ( TypeVarSubst , mkTypeVarSubst , splitKindFunTys , fromLeft, fromRight ) -- ----------------------------------------------------------------------------- -- Constants -- ----------------------------------------------------------------------------- -- | Indentation to be used in error messages. 
errIndent :: Int
errIndent = 4

-- -----------------------------------------------------------------------------
-- Constraint and type inspection
-- -----------------------------------------------------------------------------

-- | Retrieve the type constructors at top level involved in the given types.
-- If there are type constructors nested within the type they are ignored.
-- Types that are not type constructor applications contribute nothing.
--
-- /Example/:
--
-- >>> collectTopTyCons [Maybe (Identity ())]
-- { Maybe }
collectTopTyCons :: [Type] -> Set TyCon
collectTopTyCons tys = S.fromList $ catMaybes $ fmap tyConAppTyCon_maybe tys

-- | Retrieve the type constructor variables at the top level involved in the
-- given types. If there are nested type variables they are ignored.
-- There is no actual check if the returned type variables are actually type
-- constructor variables.
--
-- /Example/:
--
-- >>> collectTopTcVars [m a b, Identity c, n]
-- { m, n }
collectTopTcVars :: [Type] -> Set.Set TyVar
collectTopTcVars = Set.map fst . collectTopTcVarsWithArity

-- | Retrieve the type constructor variables at the top level involved in the
-- given types, each paired with its arity. If there are nested type variables
-- they are ignored. There is no actual check if the returned type variables
-- are actually type constructor variables.
--
-- The arity is derived from the /kind/ of the type variable (the number of
-- argument kinds in its function kind), not from the number of arguments it
-- is applied to in the given type. Both numbers coincide when the variable
-- is applied to all of its arguments.
--
-- /Example/ (assuming @m :: * -> * -> *@ and @n :: *@):
--
-- >>> collectTopTcVarsWithArity [m a b, Identity c, n]
-- { (m, 2), (n, 0) }
collectTopTcVarsWithArity :: [Type] -> Set.Set (TyVar, Arity)
collectTopTcVarsWithArity tys = Set.fromList $ catMaybes $ fmap getTyVarAndArity tys
  where
    getTyVarAndArity :: Type -> Maybe (TyVar, Arity)
    getTyVarAndArity t = do
      -- Only the head of the (possibly nested) application is inspected;
      -- the applied arguments are discarded.
      let (tf, _args) = splitAppTys t
      tv <- getTyVar_maybe tf
      return (tv, tyVarArity tv)

-- | Try to collect all type variables in a given type.
-- Does not work for Pi or ForAll types.
-- Any type (or part of a type) that is not supported contributes an
-- empty set to the result.
collectTyVars :: Type -> Set.Set TyVar
collectTyVars t
  | Just tv <- getTyVar_maybe t = Set.singleton tv
  | Just (_tc, args) <- splitTyConApp_maybe t = Set.unions $ fmap collectTyVars args
  | Just (ta, tb) <- splitFunTy_maybe t = both ta tb
  | Just (ta, tb) <- splitAppTy_maybe t = both ta tb
  | Just (_r, ta, tb) <- getEqPredTys_maybe t = both ta tb
  | otherwise = Set.empty
  where
    -- Union of the type variables found in two sub-types.
    both :: Type -> Type -> Set.Set TyVar
    both ta tb = collectTyVars ta `Set.union` collectTyVars tb

-- | Create a substitution that replaces the given type variables with their
-- associated type constructors.
mkTcVarSubst :: [(TyVar, TyCon)] -> TypeVarSubst
mkTcVarSubst substs = mkTypeVarSubst $ fmap (second mkTyConTy) substs

-- -----------------------------------------------------------------------------
-- General utilities
-- -----------------------------------------------------------------------------

-- | Override the standard bind flag of a given list of variables to 'Skolem'.
-- The standard bind flag is determined using 'instanceBindFun'.
-- This can be used to keep 'tcUnifyTys' from unifying the given variables
-- and to view them as constants.
skolemVarsBindFun :: [TyVar] -> TyVar -> BindFlag
skolemVarsBindFun tvs var
  | var `elem` tvs = Skolem
  | otherwise = instanceBindFun var

-- | Check if both types are type variables and if those type
-- variables are equal.
eqTyVar :: Type -> Type -> Bool
eqTyVar ty ty' = case getTyVar_maybe ty of
  Just tv -> eqTyVar' tv ty'
  Nothing -> False

-- | Check if the given type is a type variable and whether it is equal to
-- the given type variable.
eqTyVar' :: TyVar -> Type -> Bool
eqTyVar' tv ty = maybe False (tv ==) (getTyVar_maybe ty)

-- | Returns the arity of a given type variable, i.e. the number of argument
-- kinds in its (function) kind.
tyVarArity :: TyVar -> Arity
tyVarArity = length . fst . splitKindFunTys . tyVarKind

-- | Returns the unqualified (occurrence) name of the given type constructor
-- as a string.
getTyConName :: TyCon -> String
getTyConName = occNameString . nameOccName . tyConName

-- | Returns the name of the class (the literal name of its type constructor).
getClassName :: Class -> String
getClassName cls = getTyConName $ classTyCon cls

-- | Get the element of a list at a given index, if that element exists.
-- Negative indices and indices past the end of the list yield 'Nothing'.
atIndex :: [a] -> Int -> Maybe a
atIndex xs i
  | i < 0 = Nothing -- 'drop' treats negative counts as 0, which would yield the head.
  | otherwise = listToMaybe $ drop i xs

-- | Select first element of triple.
t1st :: (a, b, c) -> a
t1st (a, _, _) = a

-- | Select second element of triple.
t2nd :: (a, b, c) -> b
t2nd (_, b, _) = b

-- | Select third element of triple.
t3rd :: (a, b, c) -> c
t3rd (_, _, c) = c

-- | Checks if the given type is an ambiguous type variable.
isAmbiguousType :: Type -> Bool
isAmbiguousType ty = maybe False isAmbiguousTyVar $ getTyVar_maybe ty

-- | Takes a list of type variables that are associated with certain type
-- constructors or type constructor variables and partially applies them
-- to match the kind of the type variable. If the kinds cannot be matched,
-- an error message is returned in 'Left'. Example:
--
-- >>> partiallyApplyTyCons [(n :: *, Left (Maybe :: * -> *))]
-- Right [(n :: *, Maybe i :: *, [i :: *])]
--
-- >>> partiallyApplyTyCons [(n :: * -> *, Right (k :: * -> Int -> * -> *))]
-- Right [(n :: * -> *, k i j :: * -> *, [i :: *, j :: Int])]
--
-- >>> partiallyApplyTyCons [(n :: * -> *, Left (Int :: *))]
-- Left ...
--
-- >>> partiallyApplyTyCons [(n :: * -> *, Left (Maybe :: * -> *))]
-- Right [(n :: * -> *, Maybe :: * -> *, [])]
--
-- The variables generated for the partial application are flexi vars (see 'newFlexiTyVar' and 'applyTyCon').
partiallyApplyTyCons :: [(TyVar, Either TyCon TyVar)] -> TcPluginM (Either O.SDoc [(TyVar, Type, [TyVar])])
partiallyApplyTyCons [] = return (Right [])
partiallyApplyTyCons ((tv, tc) : rest)
  -- The constructor must take at least as many arguments as the variable.
  | length tcArgKinds < length tvArgKinds =
      return . Left $ mismatchMsg $$ O.ppr tcArgKinds $$ O.ppr tvArgKinds
  -- The trailing argument kinds and the result kinds have to agree.
  | not kindsAgree =
      return . Left $ mismatchMsg $$ O.ppr tc $$ O.ppr tv
  | otherwise = do
      -- The remaining associations are processed before this one.
      restResult <- partiallyApplyTyCons rest
      either (return . Left) extendWith restResult
  where
    (tvArgKinds, tvResKind) = splitKindFunOfTcTv (Right tv)
    (tcArgKinds, tcResKind) = splitKindFunOfTcTv tc
    -- Compare the argument kinds right-aligned: only the suffix of the
    -- constructor's arguments has to line up with the variable's arguments.
    kindsAgree = and (zipWith eqType (reverse tvArgKinds) (reverse tcArgKinds))
              && eqType tcResKind tvResKind
    mismatchMsg = O.text "Kind mismatch between type constructor and type variable: "
    -- Leading argument kinds that have to be filled with fresh variables so
    -- the applied constructor ends up with the variable's kind.
    missingKinds = take (length tcArgKinds - length tvArgKinds) tcArgKinds
    extendWith applied = do
      (tcTy, argVars) <- applyTyCon (tc, missingKinds)
      return (Right ((tv, tcTy, argVars) : applied))

-- | Applies a type constructor or type constructor variable to one fresh
-- flexi type variable per given kind, in order. Returns the resulting
-- (possibly partial) application together with the fresh variables.
applyTyCon :: (Either TyCon TyVar, [Kind]) -> TcPluginM (Type, [TyVar])
applyTyCon (eTcTv, ks) = do
  freshVars <- mapM newFlexiTyVar ks
  let headTy = either mkTyConTy mkTyVarTy eTcTv
  return (mkAppTys headTy (map mkTyVarTy freshVars), freshVars)

-- | Looks up the kind of the given type constructor or type variable and
-- splits it into its argument kinds and its result kind. A kind that is
-- not a function kind yields no argument kinds.
splitKindFunOfTcTv :: Either TyCon TyVar -> ([Kind], Kind)
splitKindFunOfTcTv = splitKindFunTys . either tyConKind tyVarKind

-- | Takes a list of keys and all of their possible values and returns a list
-- of all possible associations between keys and values. If any key has no
-- possible values, there are no associations and the result is empty.
--
-- /Examples/:
--
-- >>> associations [('a', [1,2,3]), ('b', [4,5])]
-- [ [('a', 1), ('b', 4)], [('a', 1), ('b', 5)]
-- , [('a', 2), ('b', 4)], [('a', 2), ('b', 5)]
-- , [('a', 3), ('b', 4)], [('a', 3), ('b', 5)] ]
associations :: [(key , [value])] -> [[(key, value)]]
associations [] = [[]]
associations ((_x, []) : _xys) = []
associations ((x, y : ys) : xys) = fmap ((x, y) :) (associations xys) ++ associations ((x, ys) : xys)

{-
-- | Generates the set of all subsets of a given set.
subsets :: (Ord a) => Set a -> Set (Set a)
subsets s = case S.size s of
  0 -> S.singleton S.empty
  _ -> let (x, s') = S.deleteFindMin s
           subs = subsets s'
       in S.map (S.insert x) subs `S.union` subs
-}

-- | Efficient removal of duplicate elements in O(n * log(n)).
-- The result list is ordered in ascending order.
removeDup :: (Ord a) => [a] -> [a]
removeDup = Set.toAscList . Set.fromList

-- | Removal of duplicate elements in a list,
-- based on their unique representation.
removeDupUnique :: (Uniquable a) => [a] -> [a]
removeDupUnique = S.toList . S.fromList

-- | Efficient removal of duplicate entries by key in O(n * log(n)).
-- If a key occurs more than once, the value of its /last/ occurrence is
-- kept (semantics of 'Map.fromList'). The result list is ordered in
-- ascending order of the keys.
removeDupByIndex :: (Ord a) => [(a,b)] -> [(a,b)]
removeDupByIndex = Map.toList . Map.fromList

-- | Removal of duplicate entries by key, based on equality of keys.
-- Only the /first/ occurrence of each key is kept and the original order
-- of the entries is preserved. Since only 'Eq' is available this runs in
-- O(n^2); prefer 'removeDupByIndex' when the keys are 'Ord'.
removeDupByIndexEq :: (Eq a) => [(a,b)] -> [(a,b)]
removeDupByIndexEq [] = []
removeDupByIndexEq ((a, b) : rest) =
  (a, b) : removeDupByIndexEq (filter ((/= a) . fst) rest)

-- | Removal of duplicate entries by key,
-- based on the unique representation of the keys.
removeDupByIndexUnique :: (Uniquable a) => [(a,b)] -> [(a,b)]
removeDupByIndexUnique = M.toList . M.fromList

-- | Like 'lookup', but uses a custom function to check equality.
-- Searches the list for the first entry whose key matches and returns the
-- associated value if such an entry is found. The searched key is passed as
-- the first argument of the equality function.
lookupBy :: (a -> a -> Bool) -> a -> [(a, b)] -> Maybe b
lookupBy eq x ybs = snd <$> find (eq x . fst) ybs

-- | Iterate over a list of items and check if the given predicate holds for
-- all of them. The predicate is run on every item (left to right), even after
-- a 'False' result has been seen.
allM :: (Monad m) => (a -> m Bool) -> [a] -> m Bool
allM = quantM (&&) True

-- | Iterate over a list of items and check if the given predicate holds for
-- at least one of them. The predicate is run on every item (left to right),
-- even after a 'True' result has been seen.
anyM :: (Monad m) => (a -> m Bool) -> [a] -> m Bool
anyM = quantM (||) False

-- | Generalization of 'allM' and 'anyM' that is used to implement each of those
-- functions. Runs the predicate on all items from left to right and combines
-- the results with a right fold of the given operator, starting from the
-- given default value.
quantM :: (Monad m) => (Bool -> Bool -> Bool) -> Bool -> (a -> m Bool) -> [a] -> m Bool
quantM comp def p xs = foldr comp def <$> mapM p xs

-- | Partition a list into two lists based on a predicate involving a monadic
-- computation. The first list contains the items the predicate holds for,
-- the second list the remaining items; both keep the original order.
-- The predicate is run on the items from left to right.
partitionM :: (Monad m) => (a -> m Bool) -> [a] -> m ([a], [a])
partitionM _ [] = return ([], [])
partitionM p (x : xs) = do
  -- Run the predicate on the current item before recursing so that its
  -- effects happen in list order, consistent with 'allM' and 'anyM'.
  b <- p x
  (ts, fs) <- partitionM p xs
  return $ if b then (x : ts, fs) else (ts, x : fs)