-- | Compiler IR Compiler (CIRC): A language for specifying compiler intermediate representations. module Language.CIRC ( -- * CIRC Specifications Spec (..) , Transform (..) , Type (..) , TypeDef (..) , CtorDef (..) , TypeRefinement (..) , Name , ModuleName , CtorName , TypeName , TypeParam , Code , Import , t , indent -- * CIRC Compilation , circ ) where import Control.Monad import Data.Function import Data.List import System.Directory import System.IO import Text.Printf -- | A specification is a module name for the initial type, common imports, the root type, the initial type definitions, and a list of transforms. data Spec = Spec Name [Import] TypeName [TypeDef] [Transform] type Name = String type ModuleName = String type CtorName = String type TypeName = String type TypeParam = String type Code = String type Import = String -- | A type expression. data Type = T TypeName [Type] | TList Type | TMaybe Type | TTuple [Type] -- | A type definition is a name, a list of type parameters, and a list of constructor definitions. data TypeDef = TypeDef TypeName [TypeParam] [CtorDef] -- | A constructor definition is a name and a list of type arguments. data CtorDef = CtorDef CtorName [Type] -- | A type refinement. data TypeRefinement = NewCtor TypeName CtorDef (ModuleName -> Code) | NewType TypeDef -- | A transform is a module name, the constructor to be transformed, a list of new type definitions, -- and the implementation (imports and code). data Transform = Transform ModuleName [Import] [Import] [(CtorName, ModuleName -> Code)] [TypeRefinement] -- | An unparameterized type. t :: String -> Type t n = T n [] -- | Compiles a CIRC spec. circ :: Spec -> IO () circ (Spec initModuleName initImports rootTypeName typeDefsUnsorted transforms) = do maybeWriteFile (initModuleName ++ ".hs") $ codeTypeModule initModuleName initImports typeDefs maybeWriteFile (initModuleName ++ "Trans.hs") $ codeInitTransModule initModuleName rootTypeName foldM_ codeTransform (initModuleName, typeDefs) transforms where typeDefs = sortTypeDefs typeDefsUnsorted codeTransform :: (Name, [TypeDef]) -> Transform -> IO (Name, [TypeDef]) codeTransform (prevModuleName, prevTypeDefs) (Transform moduleName typeImports transImports removedCtors typeRefinements) = do maybeWriteFile (moduleName ++ ".hs") $ codeTypeModule moduleName typeImports typeDefs maybeWriteFile (moduleName ++ "Trans.hs") $ codeTransModule initModuleName rootTypeName prevModuleName prevTypeDefs moduleName transImports typeDefs [ (name, code prevModuleName) | (name, code) <- removedCtors ] [ (ctorName, transCode prevModuleName) | NewCtor _ (CtorDef ctorName _) transCode <- typeRefinements ] return (moduleName, typeDefs) where filteredCtor = [ TypeDef name params [ CtorDef ctorName args | CtorDef ctorName args <- ctors, notElem ctorName $ fst $ unzip removedCtors ] | TypeDef name params ctors <- prevTypeDefs ] typeDefs = sortTypeDefs $ filterRelevantTypes rootTypeName $ nextTypes filteredCtor typeRefinements -- | Write out a file if the file doesn't exist or is different. Doesn't bump the timestamp for Makefile-like build systems. maybeWriteFile :: FilePath -> String -> IO () maybeWriteFile file contents = do a <- doesFileExist file if not a then writeFile file contents else do f <- openFile file ReadMode contents' <- hGetContents f if contents' == contents then do hClose f return () else do hClose f writeFile file contents -- | Sort a list of TypeDefs by type name. sortTypeDefs :: [TypeDef] -> [TypeDef] sortTypeDefs = sortBy (compare `on` \ (TypeDef n _ _) -> n) -- | Code the module that contains the IR datatype definitions. codeTypeModule :: ModuleName -> [Import] -> [TypeDef] -> String codeTypeModule moduleName imports typeDefs = unlines $ [ printf "module %s" moduleName , " ( " ++ intercalate "\n , " [ name ++ " (..)"| TypeDef name _ _ <- typeDefs ] , " ) where" , "" ] ++ nub (["import Language.CIRC.Runtime"] ++ imports) ++ [""] ++ map codeTypeDef typeDefs where codeTypeDef :: TypeDef -> String codeTypeDef (TypeDef name params ctors) = "data " ++ name ++ " " ++ intercalate " " params ++ "\n = " ++ intercalate "\n | " [ name ++ replicate (m - length name) ' ' ++ " " ++ intercalate " " (map codeType args) | CtorDef name args <- ctors' ] ++ "\n" where ctors' = sortBy (compare `on` \ (CtorDef n _) -> n) ctors m = maximum [ length n | CtorDef n _ <- ctors ] codeType :: Type -> String codeType a = case a of T name [] -> name T name params -> "(" ++ name ++ intercalate " " (map codeType params) ++ ")" TList a -> "[" ++ codeType a ++ "]" TMaybe a -> "(Maybe " ++ codeType a ++ ")" TTuple a -> "(" ++ intercalate ", " (map codeType a) ++ ")" -- | Code the initial transform module. codeInitTransModule :: ModuleName -> TypeName -> String codeInitTransModule moduleName rootTypeName = unlines [ printf "module %sTrans" moduleName , printf " ( transform" , printf " , transform'" , printf " ) where" , printf "" , printf "import Language.CIRC.Runtime (CIRC)" , printf "import %s (%s)" moduleName rootTypeName , printf "" , printf "transform :: %s -> CIRC (%s, [%s])" rootTypeName rootTypeName rootTypeName , printf "transform a = return (a, [a])" , printf "" , printf "transform' :: %s -> CIRC %s" rootTypeName rootTypeName , printf "transform' = return" , printf "" ] -- | Code the module that contains the IR transformations. codeTransModule :: ModuleName -> TypeName -> ModuleName -> [TypeDef] -> ModuleName -> [Import] -> [TypeDef] -> [(CtorName, Code)] -> [(CtorName, Code)] -> String codeTransModule initModuleName rootTypeName prevModuleName prevTypeDefs moduleName imports typeDefs transCode backwardTransCode = unlines $ [ printf "module %sTrans" moduleName , " ( transform" , " , transform'" , " ) where" , "" ] ++ nub ( [ "import Language.CIRC.Runtime" , "import qualified " ++ initModuleName , "import qualified " ++ prevModuleName , "import qualified " ++ prevModuleName ++ "Trans" , "import " ++ moduleName ] ++ imports) ++ [ printf "" , printf "transform :: %s.%s -> CIRC (%s, [%s.%s])" initModuleName rootTypeName rootTypeName initModuleName rootTypeName , printf "transform a = do" , printf " (a, b) <- %sTrans.transform a" prevModuleName , printf " a <- trans%s a" rootTypeName , printf " c <- transform' a" , printf " return (a, b ++ [c])" , printf "" , printf "transform' :: %s -> CIRC %s.%s" rootTypeName initModuleName rootTypeName , printf "transform' a = trans%s' a >>= %sTrans.transform'" rootTypeName prevModuleName , printf "" , codeTypeTransforms prevModuleName prevTypeDefs typeDefs transCode backwardTransCode , printf "" ] -- | Codes the type transform function. codeTypeTransforms :: ModuleName -> [TypeDef] -> [TypeDef] -> [(CtorName, Code)] -> [(CtorName, Code)] -> String codeTypeTransforms prevName prevTypes currTypes forwardTrans backwardTrans = concatMap (codeTypeTransform prevTypes forwardTrans (\ t -> "trans" ++ t) qualified id) [ t | t@(TypeDef n _ _) <- prevTypes, elem n $ map typeDefName currTypes ] ++ concatMap (codeTypeTransform currTypes backwardTrans (\ t -> "trans" ++ t ++ "'") id qualified) [ t | t@(TypeDef n _ _) <- currTypes, elem n $ map typeDefName prevTypes ] where typeDefName (TypeDef n _ _) = n qualified :: String -> String qualified a = prevName ++ "." ++ a vars = map (: []) ['a' .. 'z'] codeTypeTransform :: [TypeDef] -> [(CtorName, Code)] -> (TypeName -> String) -> (CtorName -> String) -> (CtorName -> String) -> TypeDef -> String codeTypeTransform fromTypes transforms transName from to (TypeDef typeName _params ctors) = unlines $ -- XXX What do we do with type params? [ transName typeName ++ " :: " ++ from typeName ++ " -> CIRC " ++ to typeName , transName typeName ++ " a = case a of" , indent $ unlines $ map codeCtor ctors ] where codeCtor :: CtorDef -> String codeCtor (CtorDef ctorName ctorArgs) = case lookup ctorName transforms of Nothing -> from ctorName ++ args ++ " -> do { " ++ impArgs ctorArgs ++ "return $ " ++ to ctorName ++ args ++ " }" Just code -> "\n{- Transform Begin -}\n" ++ (from $ drop 2 $ indent code) ++ "{- Transform End -}\n" where args = concat [ ' ' : v | v <- take (length ctorArgs) vars ] impArgs :: [Type] -> Code impArgs types = concatMap wrapArg $ zip vars types wrapArg :: (Name, Type) -> Code wrapArg (var, typ) = printf "%s <- %s %s; " var (codeArg typ) var codeArg :: Type -> Code codeArg typ = case typ of t | not $ any (flip elem [ name | TypeDef name _ _ <- fromTypes ]) $ primitiveTypes t -> "return" T t _ -> transName t TList t -> printf "mapM (%s)" $ codeArg t TMaybe t -> printf "(\\ a -> case a of { Nothing -> return Nothing; Just a -> do { a <- %s a; return $ Just $ a } })" $ codeArg t TTuple ts -> printf "(\\ (%s) -> do { %sreturn (%s) })" args (impArgs ts) args where args = intercalate ", " $ take (length ts) vars -- | Returns a list of names of all primitive types used in a type. primitiveTypes :: Type -> [TypeName] primitiveTypes a = case a of T n _ -> [n] TList t -> primitiveTypes t TMaybe t -> primitiveTypes t TTuple ts -> concatMap primitiveTypes ts -- | Indents code with 2 spaces. indent :: String -> String indent = unlines . map (" " ++) . lines -- | Computes the next type definitions given a list of type definitions and a list of type refinements. nextTypes :: [TypeDef] -> [TypeRefinement] -> [TypeDef] nextTypes old new = foldl nextType old new where nextType :: [TypeDef] -> TypeRefinement -> [TypeDef] nextType types refinement = case refinement of NewType t -> t : types NewCtor typeName ctorDef _ -> case match of [] -> error $ "Type not found: " ++ typeName _ : _ : _ -> error $ "Redundent type name: " ++ typeName [TypeDef _ params ctors] -> TypeDef typeName params (ctorDef : ctors) : rest where (match, rest) = partition (\ (TypeDef name _ _) -> name == typeName) types -- | Get rid of types that are not relevant to the root type. filterRelevantTypes :: TypeName -> [TypeDef] -> [TypeDef] filterRelevantTypes rootTypeName types = [ t | t@(TypeDef n _ _) <- types, elem n required ] where typeDeps :: TypeName -> [TypeName] typeDeps name = nub $ concat [ concat [ concatMap primitiveTypes t | CtorDef _ t <- ctors ] | TypeDef n _ ctors <- types, n == name ] required = next ([], [rootTypeName]) next :: ([TypeName], [TypeName]) -> [TypeName] next (sofar, remaining) = case remaining of [] -> sofar a : rest | elem a sofar -> next (sofar, rest) | otherwise -> next (a : sofar, rest ++ typeDeps a)