{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE FunctionalDependencies #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE UndecidableInstances #-}
-- | Small helpers for taking apart annotated syntax trees: module
-- components, declaration heads, bound names and name conversions.
module Language.Haskell.Names.SyntaxUtils
  ( dropAnn
  , setAnn
  , getModuleName
  , getImports
  , getExportSpecList
  , splitDeclHead
  , getDeclHeadName
  , getModuleDecls
  , isTypeDecl
  , GetBound(..)
  , opName
  , isCon
  , nameToString
  , specialConToString
  , qNameToName
  , unCName
  , getErrors
    -- export ExtensionSet here for the outside users
  , ExtensionSet
  , moduleExtensions
  ) where

import Prelude hiding (concatMap)

import Data.Char
import Data.Data
import Data.Maybe
import Data.Either
import Data.Generics.Uniplate.Data
import Data.Foldable
import qualified Data.Set as Set
import Language.Haskell.Exts.Annotated
import Language.Haskell.Names.Types

-- | Replace every annotation in a functorial structure with @()@.
dropAnn :: (Functor a) => a l -> a ()
dropAnn = setAnn ()

-- | Replace every annotation in a functorial structure with the given value.
setAnn :: (Functor a) => l' -> a l -> a l'
setAnn l = fmap (const l)

-- | The name of a module.
--
-- When the module carries no explicit head, the name is @main_mod@ applied
-- to the module's own annotation.
getModuleName :: Module l -> ModuleName l
getModuleName m =
  case m of
    Module _ (Just (ModuleHead _ name _ _)) _ _ _ -> name
    XmlPage _ name _ _ _ _ _ -> name
    XmlHybrid _ (Just (ModuleHead _ name _ _)) _ _ _ _ _ _ _ -> name
    _ -> main_mod (ann m)

-- | The import declarations of a module; an 'XmlPage' yields none.
getImports :: Module l -> [ImportDecl l]
getImports m =
  case m of
    Module _ _ _ imports _ -> imports
    XmlPage _ _ _ _ _ _ _ -> []
    XmlHybrid _ _ _ imports _ _ _ _ _ -> imports

-- | The top-level declarations of a module; an 'XmlPage' yields none.
getModuleDecls :: Module l -> [Decl l]
getModuleDecls m =
  case m of
    Module _ _ _ _ decls -> decls
    XmlPage _ _ _ _ _ _ _ -> []
    XmlHybrid _ _ _ _ decls _ _ _ _ -> decls

-- | The export list of a module, taken from the head produced by
-- 'getModuleHead' (so a head-less module reports an export list
-- containing only @main@).
getExportSpecList :: Module l -> Maybe (ExportSpecList l)
getExportSpecList m =
  case getModuleHead m of
    ModuleHead _ _ _ exports -> exports

-- | The module head, or a synthesised one when none is present.
--
-- The synthesised head is named by @main_mod@ and exports the single
-- variable @main@; all of its annotations are the module's annotation.
getModuleHead :: Module l -> ModuleHead l
getModuleHead m =
  case m of
    Module _ (Just mh) _ _ _ -> mh
    XmlHybrid _ (Just mh) _ _ _ _ _ _ _ -> mh
    _ -> ModuleHead l (main_mod l) Nothing (Just (ExportSpecList l [mainExport]))
  where
    l = ann m
    mainExport = EVar l (UnQual l (Ident l "main"))

-- | Forget the qualifier of a name.
--
-- Special constructors become an 'Ident' whose text is given by
-- 'specialConToString'.
qNameToName :: QName l -> Name l
qNameToName qn =
  case qn of
    UnQual _ n -> n
    Qual _ _ n -> n
    Special l s -> Ident l (specialConToString s)

{-
getImportDecls :: Module l -> [ImportDecl l]
getImportDecls (Module _ _ _ is _) = is
getImportDecls (XmlPage _ _ _ _ _ _ _) = []
getImportDecls (XmlHybrid _ _ _ is _ _ _ _ _) = is
-}

-- | The declaration head of a type, type family, data, GADT, data family
-- or class declaration; 'Nothing' for every other kind of declaration.
getDeclHead :: Decl l -> Maybe (DeclHead l)
getDeclHead decl =
  case decl of
    TypeDecl _ dhead _ -> Just dhead
    TypeFamDecl _ dhead _ -> Just dhead
    DataDecl _ _ _ dhead _ _ -> Just dhead
    GDataDecl _ _ _ dhead _ _ _ -> Just dhead
    DataFamDecl _ _ dhead _ -> Just dhead
    ClassDecl _ _ dhead _ _ -> Just dhead
    _ -> Nothing

-- | Split a declaration head into the declared name and its type variable
-- binders, looking through parentheses.
splitDeclHead :: DeclHead l -> (Name l, [TyVarBind l])
splitDeclHead dh =
  case dh of
    DHead _ n vs -> (n, vs)
    DHInfix _ v1 n v2 -> (n, [v1, v2])
    DHParen _ inner -> splitDeclHead inner

-- | The name introduced by a declaration's head.
--
-- Calls 'error' for declarations on which 'getDeclHead' returns 'Nothing'.
getDeclHeadName :: Decl l -> Name l
getDeclHeadName decl =
  case getDeclHead decl of
    Just dhead -> fst (splitDeclHead dhead)
    Nothing -> error "getDeclHeadName"

----------------------------------------------------
-- Get bound value identifiers.

-- | Syntax fragments from which the bound names can be collected.
-- The annotation type @l@ is determined by the fragment type @a@.
class GetBound a l | a -> l where
    -- | The names bound by the given fragment.
    getBound :: a -> [Name l]

-- XXX account for shadowing?
-- | A list binds everything its elements bind, in order.
instance (GetBound a l) => GetBound [a] l where
    getBound = concatMap getBound

-- | 'Nothing' binds nothing; 'Just' binds what its payload binds.
instance (GetBound a l) => GetBound (Maybe a) l where
    getBound Nothing = []
    getBound (Just x) = getBound x

-- | A pair binds the names of its first component followed by those of
-- its second.
instance (GetBound a l, GetBound b l) => GetBound (a, b) l where
    getBound (a, b) = getBound a ++ getBound b

instance (Data l) => GetBound (Binds l) l where
    getBound (BDecls _ ds) = getBound ds
    getBound (IPBinds _ _) = []  -- XXX doesn't bind regular identifiers

-- | Value-level names introduced by a declaration: data constructors and
-- record fields, class method signatures, function and pattern bindings,
-- and foreign import/export names. All other declarations bind nothing.
instance (Data l) => GetBound (Decl l) l where
    getBound (TypeDecl{}) = []
    getBound (TypeFamDecl{}) = []
    getBound (DataDecl _ _ _ _ ds _) = getBound ds
    getBound (GDataDecl _ _ _ _ _ ds _) = getBound ds
    getBound (DataFamDecl{}) = []
    getBound (TypeInsDecl{}) = []
    getBound (DataInsDecl _ _ _ ds _) = getBound ds
    getBound (GDataInsDecl _ _ _ _ ds _) = getBound ds
    getBound (ClassDecl _ _ _ _ mds) = getBound mds
    getBound (InstDecl{}) = []
    getBound (DerivDecl{}) = []
    getBound (InfixDecl{}) = []
    getBound (DefaultDecl{}) = []
    getBound (SpliceDecl{}) = []
    getBound (TypeSig{}) = []
    -- A function binding without any match is treated as a malformed tree.
    getBound (FunBind _ []) = error "getBound: FunBind []"
    -- Only the first match is inspected for the function's name.
    getBound (FunBind _ (Match _ n _ _ _ : _)) = [n]
    getBound (FunBind _ (InfixMatch _ _ n _ _ _ : _)) = [n]
    getBound (PatBind _ p _ _ _) = getBound p
    getBound (ForImp _ _ _ _ n _) = [n]
    getBound (ForExp _ _ _ n _) = [n]
    getBound (RulePragmaDecl{}) = []
    getBound (DeprPragmaDecl{}) = []
    getBound (WarnPragmaDecl{}) = []
    getBound (InlineSig{}) = []
    getBound (SpecSig{}) = []
    getBound (SpecInlineSig{}) = []
    getBound (InstSig{}) = []
    getBound (AnnPragma{}) = []
    getBound (InlineConlikeSig{}) = []

instance (Data l) => GetBound (QualConDecl l) l where
    getBound (QualConDecl _ _ _ d) = getBound d

instance (Data l) => GetBound (GadtDecl l) l where
    getBound (GadtDecl _ n _) = [n]

-- | A constructor declaration binds the constructor name and, for record
-- syntax, every field name as well.
instance (Data l) => GetBound (ConDecl l) l where
    getBound (ConDecl _ n _) = [n]
    getBound (InfixConDecl _ _ n _) = [n]
    getBound (RecDecl _ n fs) = n : getBound fs

instance (Data l) => GetBound (FieldDecl l) l where
    getBound (FieldDecl _ ns _) = ns

-- | Inside a class body only type signatures contribute names (see
-- 'getBoundSign'); associated type and data declarations bind nothing.
instance (Data l) => GetBound (ClassDecl l) l where
    getBound (ClsDecl _ d) = getBoundSign d
    getBound (ClsDataFam{}) = []
    getBound (ClsTyFam{}) = []
    getBound (ClsTyDef{}) = []

instance (Data l) => GetBound (Match l) l where
    getBound (Match _ n _ _ _) = [n]
    getBound (InfixMatch _ _ n _ _ _) = [n]

instance (Data l) => GetBound (Stmt l) l where
    getBound (Generator _ pat _) = getBound pat
    getBound (LetStmt _ bnds) = getBound bnds
    getBound (RecStmt _ stmts) = getBound stmts
    getBound (Qualifier{}) = []

-- | Only the plain-statement form of a qualifier statement contributes
-- names; every other form yields none.
instance (Data l) => GetBound (QualStmt l) l where
    getBound (QualStmt _ stmt) = getBound stmt
    getBound _ = []

-- | The names declared by a type signature; empty for any other declaration.
getBoundSign :: Decl l -> [Name l]
getBoundSign (TypeSig _ ns _) = ns
getBoundSign _ = []

-- | Variables bound by a pattern: plain variables, as-pattern names and
-- the variable of an n+k pattern, collected from every nested sub-pattern.
--
-- NOTE(review): names introduced by record field puns or record wildcards
-- are not collected here -- confirm whether callers need them.
instance (Data l) => GetBound (Pat l) l where
    getBound p = [ n | sub <- universe (transform dropExp p), n <- varp sub ]
      where
        varp (PVar _ n) = [n]
        varp (PAsPat _ n _) = [n]
        -- @n+k@ binds @n@ just like a variable pattern does
        varp (PNPlusK _ n _) = [n]
        varp _ = []

        -- must remove nested Exp so universe doesn't descend into them
        dropExp (PViewPat _ _ x) = x
        dropExp x = x

-- | Whether the declaration is a type synonym, type family, data type,
-- GADT or data family declaration.
isTypeDecl :: Decl l -> Bool
isTypeDecl (TypeDecl _ _ _) = True
isTypeDecl (TypeFamDecl _ _ _) = True
isTypeDecl (DataDecl _ _ _ _ _ _) = True
isTypeDecl (GDataDecl _ _ _ _ _ _ _) = True
isTypeDecl (DataFamDecl _ _ _ _) = True
isTypeDecl _ = False

-- | The name wrapped by an operator.
opName :: Op l -> Name l
opName (VarOp _ n) = n
opName (ConOp _ n) = n

-- | Whether a name is a constructor name: an identifier whose first
-- character is upper case, or a symbol starting with @:@.
isCon :: Name l -> Bool
isCon (Ident _ (c:_)) = isUpper c
isCon (Symbol _ (':':_)) = True
isCon _ = False

-- | The text of a name, whether identifier or symbol.
nameToString :: Name l -> String
nameToString (Ident _ s) = s
nameToString (Symbol _ s) = s

-- | Textual form of a special constructor.
--
-- A boxed tuple of arity @n@ is rendered as @n-1@ commas; an unboxed one
-- as @#@ followed by @n-1@ commas.
specialConToString :: SpecialCon l -> String
specialConToString (UnitCon _) = "()"
specialConToString (ListCon _) = "[]"
specialConToString (FunCon _) = "->"
specialConToString (TupleCon _ Boxed n) = replicate (n-1) ','
specialConToString (TupleCon _ Unboxed n) = '#':replicate (n-1) ','
specialConToString (Cons _) = ":"
specialConToString (UnboxedSingleCon _) = "#"

-- | The name wrapped by a component name.
unCName :: CName l -> Name l
unCName (VarName _ n) = n
unCName (ConName _ n) = n

-- | Collect every error attached to a scope annotation in the structure.
getErrors :: (Ord l, Foldable a) => a (Scoped l) -> Set.Set (Error l)
getErrors = foldl' collect Set.empty
  where
    collect errors (Scoped (ScopeError e) _) = Set.insert e errors
    collect errors _ = errors

-- | Compute the extension set for the given module, based on the global
-- preferences (e.g. specified on the command line) and module's LANGUAGE
-- pragmas.
--
-- A language named in the module's pragmas takes precedence over the base
-- language; module extensions are appended to the global ones.
moduleExtensions
  :: Language    -- ^ base language
  -> [Extension] -- ^ global extensions
  -> Module l
  -> ExtensionSet
moduleExtensions globalLang globalExts m =
  Set.fromList (toExtensionList lang (globalExts ++ modExts))
  where
    (mbModLang, modExts) = getModuleExtensions m
    lang = fromMaybe globalLang mbModLang

-- | Split the names listed in a module's LANGUAGE pragmas into the
-- language they select (if any) and the extensions they mention.
--
-- A name that 'classifyLanguage' recognises is treated as a language and
-- everything else is passed through 'parseExtension'. When several
-- languages are named, the last one wins. Symbol names are ignored.
getModuleExtensions :: Module l -> (Maybe Language, [Extension])
getModuleExtensions m = (modLang, exts)
  where
    pragmas =
      case m of
        Module _ _ ps _ _ -> ps
        XmlPage _ _ ps _ _ _ _ -> ps
        XmlHybrid _ _ ps _ _ _ _ _ _ -> ps

    pragmaNames =
      [ name
      | LanguagePragma _ ns <- pragmas
      , Ident _ name <- ns
      ]

    classified :: [Either Language Extension]
    classified = map classify pragmaNames

    classify name =
      case (parseExtension name, classifyLanguage name) of
        (e, UnknownLanguage {}) -> Right e
        (_, l) -> Left l

    (langs, exts) = partitionEithers classified

    modLang =
      case langs of
        [] -> Nothing
        _ -> Just (last langs)