module Main where

import System.IO
import System.Console.GetOpt
import System.Environment
import Language.Haskell.Exts.Syntax as Exts
import Language.Haskell.Exts.Parser
import Language.Haskell.Exts.Pretty
import Data.Maybe
import Data.List
import PwPf
import Matching
import FunctorOf
import Hylos
import Language.Pointfree.Pretty
import Language.Pointfree.Syntax as Pf
import Language.Pointwise.Syntax as Pw
import Language.Pointwise.Pretty
import Language.Pointwise.Parser
import Language.Pointwise.Matching
import Generics.Pointless.Combinators
import Control.Monad.State
import Data.Generics.Schemes
import Data.Generics.Aliases

-- Managing Options

-- | Command-line flags understood by DrHylo.
data Flag
  = Input String   -- ^ input file; "stdin" selects standard input
  | Output String  -- ^ output file; "stdout" selects standard output
  | Fixify         -- ^ use fixpoints instead of hylomorphisms
  | Pointwise      -- ^ do not convert to point-free
  | Observable     -- ^ generate observable hylomorphisms
  deriving Eq

-- | Option table for 'getOpt'.
options :: [OptDescr Flag]
options =
  [ Option ['o'] ["output"] (OptArg outp "FILE") "output FILE"
  , Option ['i'] ["input"] (OptArg inp "FILE") "input FILE"
  , Option ['f'] ["fix"] (NoArg Fixify) "use fixpoints instead of hylomorphisms"
  , Option ['w'] ["pointwise"] (NoArg Pointwise) "do not convert to point-free"
  , Option ['O'] ["observable"] (NoArg Observable) "generate observable hylomorphisms"
  ]

-- | Build the input/output flags; a missing FILE argument defaults to the
-- standard stream.
inp, outp :: Maybe String -> Flag
outp = Output . fromMaybe "stdout"
inp = Input . fromMaybe "stdin"

-- | Parse the command line. Fails (in IO) with the getOpt error messages,
-- a line per unexpected non-option argument, and the usage text.
parseOpts :: [String] -> IO [Flag]
parseOpts opts =
  case getOpt Permute options opts of
    (l, [], [])        -> return l
    (_, nonOpts, errs) ->
      -- Previously stray arguments were rejected with only the usage text;
      -- name them so the user knows what went wrong.
      fail (concat errs ++ concatMap unexpected nonOpts ++ "\n" ++ usageInfo header options)
  where
    unexpected a = "unexpected argument: " ++ a ++ "\n"
    header = "DrHylo derives point-free hylomorphisms from restricted Haskell syntax\n\nUsage: DrHylo [OPTION...]"

-- | Handle for the first 'Input' flag (standard input if none, or if it
-- names "stdin").
getInput :: [Flag] -> IO Handle
getInput [] = return stdin
getInput (Input i : _)
  | i == "stdin" = return stdin
  | otherwise    = openFile i ReadMode
getInput (_ : l) = getInput l

-- | Handle for the first 'Output' flag (standard output if none, or if it
-- names "stdout"). A named file is opened in 'WriteMode', i.e. truncated.
getOutput :: [Flag] -> IO Handle
getOutput [] = return stdout
getOutput (Output i : _)
  | i == "stdout" = return stdout
  | otherwise     = openFile i WriteMode
getOutput (_ : l) = getOutput l

-- | Was @--fix@ given?
fixrequired :: [Flag] -> Bool
fixrequired = elem Fixify

-- | Was @--pointwise@ given?
pwrequired :: [Flag] -> Bool
pwrequired = elem Pointwise

-- | Was @--observable@ given?
obrequired :: [Flag] -> Bool
obrequired = elem Observable

-- Parsing

-- | Parse a Haskell module, failing in IO with "location: message".
parse :: String -> IO Module
parse s =
  case parseModule s of
    ParseOk m        -> return m
    ParseFailed l d  -> fail (show l ++ ": " ++ d)

-- Generation of Observable function contexts

-- | True iff the declaration is a type signature that (possibly among
-- several names) declares @name@.
isTypeSig :: String -> Decl -> Bool
isTypeSig name (TypeSig _ x _) = Ident name `elem` x
isTypeSig _ _ = False

-- | Every type variable occurring anywhere in a type, in traversal order
-- (duplicates kept).
getTypeVars :: Exts.Type -> [Name]
getTypeVars = everything (++) ([] `mkQ` getVar)
  where
    getVar :: Exts.Type -> [Name]
    getVar (TyVar v) = [v]
    getVar _ = []

-- | For a function signature @a -> b@ (optionally already under a
-- forall/context), add a @typeable v@ and an @observable v@ constraint for
-- each type variable @v@ that occurs in both @a@ and @b@. @typeable@ and
-- @observable@ are class names defined elsewhere (presumably Typeable and
-- Observable). Non-function types and non-signature declarations are
-- returned unchanged.
addTypeSig :: Decl -> Decl
addTypeSig (TypeSig loc names t) = TypeSig loc names (aux t)
  where
    aux (TyForall mb ctx (TyFun a b)) = TyForall mb (extend ctx a b) (TyFun a b)
    aux (TyFun a b)                   = TyForall Nothing (extend [] a b) (TyFun a b)
    -- Anything else used to crash with a pattern-match failure.
    aux other                         = other
    -- nub: do not repeat constraints the user already wrote.
    extend ctx a b = nub (ctx ++ inst typeable a b ++ inst observable a b)
    vars a b = nub (intersect (getTypeVars a) (getTypeVars b))
    inst cl a b = map (mkInsVar cl) (vars a b)
addTypeSig d = d

-- | The class assertion @cl n@ for a type variable @n@.
mkInsVar :: Name -> Name -> Asst
mkInsVar cl n = ClassA (UnQual cl) [TyVar n]

-- | Add the observable contexts to every type signature declaring @n@.
addTypeableObservableIns :: String -> [Decl] -> [Decl]
addTypeableObservableIns n = map annotate
  where
    annotate d
      | isTypeSig n d = addTypeSig d
      | otherwise     = d
-- From Pointwise to Point-free (or not)

-- | Transform every declaration 'pwpfDecl' accepts, leaving the others
-- untouched. Prepends a LANGUAGE pragma (TypeFamilies, plus
-- DeriveDataTypeable when @--observable@ is set). With @--observable@, the
-- type signatures of bindings turned into hylomorphisms also receive the
-- extra class constraints (see 'addTypeableObservableIns').
pwpfModule :: [Flag] -> [(String,Pw.Term)] -> Module -> Module
pwpfModule f c (Module loc name pragmas warnings exports imports decls) =
    Module loc name pragmas' warnings exports imports decls''
  where
    (decls', obs) = (id >< catMaybes) $ unzip $ map convert decls
    decls''
      | obrequired f = foldr addTypeableObservableIns decls' obs
      | otherwise    = decls'
    -- Fixed: the first name used to be "TypeFamilies," (stray comma), which
    -- rendered a malformed pragma.
    pragmaNames
      | obrequired f = ["TypeFamilies", "DeriveDataTypeable"]
      | otherwise    = ["TypeFamilies"]
    pragmas' = LanguagePragma loc (map Ident pragmaNames) : pragmas
    convert d = fromMaybe (d, Nothing) (pwpfDecl f c d)

-- | Pointwise encodings of the list constructors:
-- @[]@ is @In (Inl Unit)@ and @(:)@ is @\\h t -> In (Inr (h :&: t))@.
consts :: [(String,Pw.Term)]
consts =
  [ ("[]", In (Inl Unit))
  , (":", Lam "h" (Lam "t" (In (Inr (Pw.Var "h" :&: Pw.Var "t")))))
  ]

-- | Translate a simple binding @name = rhs@ (no arguments, guards or local
-- declarations). Returns the rewritten binding paired with @Just name@ when
-- it was turned into a hylomorphism (@Nothing@ otherwise). The whole result
-- is @Nothing@ when the declaration is not such a binding or one of the
-- translation steps fails.
pwpfDecl :: [Flag] -> [(String,Pw.Term)] -> Decl -> Maybe (Decl,Maybe String)
pwpfDecl f d (PatBind loc (PVar (Ident name)) (UnGuardedRhs rhs) (BDecls [])) = do
    pw <- hs2pw rhs
    let pw0 = step (replace (d ++ consts) pw)
    pw1 <- evalStateT (nomatch pw0) 0
    -- A binding that mentions its own name is recursive: close it with Fix.
    -- Remaining free variables are then turned into constants.
    let pw2 = if name `elem` free pw1 then Pw.Fix (Lam name pw1) else pw1
        pw3 = subst (map (\v -> (v, Pw.Const v)) (free pw2)) pw2
        (rhs', ob) = translate pw3
    return (PatBind loc (PVar (Ident name)) (UnGuardedRhs rhs') (BDecls []), ob)
  where
    -- Choose the output form requested by the flags.
    translate term
      | pwrequired f = (pw2hs term, Nothing)
      | not (fixrequired f) && derivable term =
          -- derivable is expected to guarantee this shape.
          let (Pw.Fix (Lam nam (Lam x z))) = term
              functorTerm   = fun z nam
              algebraTerm   = Lam "__" (alg z nam (Pw.Var "__"))
              coalgebraTerm = Lam x (coa z nam)
              hyloCon       = if obrequired f then HyloO else Hylo
          in ( pf2hs (hyloCon (Pf.Fix functorTerm)
                              (unpoint (pwpf [] algebraTerm))
                              (unpoint (pwpf [] coalgebraTerm)))
             , Just name )
      | otherwise = (pf2hs (unpoint (pwpf [] term)), Nothing)
pwpfDecl _ _ _ = fail "The transformation must be applied to simple declarations"

-- Handle imports

-- | Placeholder source location for generated imports.
loc0 :: SrcLoc
loc0 = SrcLoc "" 0 0

-- | An unqualified, unrestricted @import n@.
mkImportDecl :: String -> ImportDecl
mkImportDecl n = ImportDecl loc0 (ModuleName n) False False Nothing Nothing

-- | Name of the imported module.
getImportName :: ImportDecl -> String
getImportName (ImportDecl _ (ModuleName n) _ _ _ _) = n

-- | Append the library imports the generated code needs (plus the
-- observation modules when @b@ is True), skipping any already imported.
handleImports :: Bool -> Module -> Module
handleImports b (Module loc name pragmas warnings exports imports decls) =
    Module loc name pragmas warnings exports imports' decls
  where
    base = [ "Generics.Pointless.Combinators"
           , "Generics.Pointless.Functors"
           , "Generics.Pointless.RecursionPatterns" ]
    observing = [ "Data.Typeable"
                , "Debug.Observe"
                , "Generics.Pointless.Observe.Functors"
                , "Generics.Pointless.Observe.RecursionPatterns" ]
    required = if b then base ++ observing else base
    missing  = required \\ map getImportName imports
    imports' = imports ++ map mkImportDecl missing

-- Main

-- | Read, transform and pretty-print a module. The output file is only
-- opened once the result has been computed, so a parse or transformation
-- failure no longer truncates it.
main :: IO ()
main = do
    flags <- getArgs >>= parseOpts
    let ob = obrequired flags
    ihandle <- getInput flags
    source <- hGetContents ihandle
    hsModule <- parse source
    let hsModule0 = casificate hsModule
        hsModule1 = functorOfInst ob hsModule0
        hsModule2 = pwpfModule flags (getCtx hsModule1) hsModule1
        rendered  = prettyPrint (handleImports ob hsModule2)
    -- Force the rendered text (its full spine) before opening the output
    -- file. This also lets the input be fully consumed and closed first, so
    -- in-place use (-i f -o f) works.
    _ <- return $! length rendered
    hClose ihandle
    ohandle <- getOutput flags
    hPutStrLn ohandle rendered
    hClose ohandle