{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE ViewPatterns #-}

-- Undo pointfree transformations. Plugin code derived from Pl.hs.
module Lambdabot.Pointful (pointful) where

import Lambdabot.Parser (withParsed, prettyPrintInLine)

-- liftM and replicateM are imported explicitly: mtl >= 2.3 no longer
-- re-exports Control.Monad from Control.Monad.Reader / Control.Monad.State.
import Control.Monad (liftM, replicateM)
import Control.Monad.Reader
import Control.Monad.State
import Data.Functor.Identity (Identity)
import Data.Generics
import qualified Data.Set as S
import qualified Data.Map as M
import Data.List
import Data.Maybe
import Language.Haskell.Exts.Simple as Hs

---- Utilities ----

-- iterate a function until it reaches a fixed point (as judged by (==))
stabilize :: Eq a => (a -> a) -> a -> a
stabilize f x = let x' = f x in if x' == x then x else stabilize f x'

-- varsBoundHere returns variables bound by top patterns or binders
--
-- Expressions are never entered, so binders nested inside an expression
-- are not reported.  For function bindings (prefix or infix) only the
-- function name itself is reported.
varsBoundHere :: Data d => d -> S.Set Name
varsBoundHere (cast -> Just (PVar name)) = S.singleton name
varsBoundHere (cast -> Just (Match name _ _ _)) = S.singleton name
varsBoundHere (cast -> Just (InfixMatch _ name _ _ _)) = S.singleton name
varsBoundHere (cast -> Just (PatBind pat _ _)) = varsBoundHere pat
varsBoundHere (cast -> Just (_ :: Exp)) = S.empty
varsBoundHere d = S.unions (gmapQ varsBoundHere d)

-- note: the tempting idea of using a pattern synonym for the frequent
-- (cast -> Just _) patterns causes compiler crashes with ghc before
-- version 8; cf. https://ghc.haskell.org/trac/ghc/ticket/11336

-- generic fold over the unqualified variable occurrences of a term
--
-- 'var' is called for each occurrence with the variable's name and the set
-- of names bound around that occurrence; 'combine' merges the results of
-- the children of a node.
foldFreeVars :: forall a d. Data d => (Name -> S.Set Name -> a) -> ([a] -> a) -> d -> a
foldFreeVars var combine e = runReader (go e) S.empty
  where
    go :: forall d. Data d => d -> Reader (S.Set Name) a
    go (cast -> Just (Var (UnQual name))) =
        asks (var name)
    go (cast -> Just (Lambda ps body)) =
        bind [varsBoundHere ps] $ go body
    go (cast -> Just (Let bs body)) =
        bind [varsBoundHere bs] $ collect [go bs, go body]
    go (cast -> Just (Alt pat rhs bs)) =
        bind [varsBoundHere pat, varsBoundHere bs] $ collect [go rhs, go bs]
    go (cast -> Just (PatBind pat rhs bs)) =
        bind [varsBoundHere pat, varsBoundHere bs] $ collect [go rhs, go bs]
    go (cast -> Just (Match _ ps rhs bs)) =
        bind [varsBoundHere ps, varsBoundHere bs] $ collect [go rhs, go bs]
    -- infix function definitions bind their arguments just like prefix ones
    go (cast -> Just (InfixMatch p _ ps rhs bs)) =
        bind [varsBoundHere (p : ps), varsBoundHere bs] $ collect [go rhs, go bs]
    go d = collect (gmapQ go d)

    -- run all sub-computations and merge their results
    collect :: forall m. Monad m => [m a] -> m a
    collect ms = combine `liftM` sequence ms

    -- run a computation with additional names marked as bound
    bind :: forall a b. Ord a => [S.Set a] -> Reader (S.Set a) b -> Reader (S.Set a) b
    bind ss = local (S.unions ss `S.union`)

-- return free variables
freeVars :: Data d => d -> S.Set Name
freeVars = foldFreeVars (\name bv -> S.singleton name `S.difference` bv) S.unions

-- return number of free occurrences of a variable
countOcc :: Data d => Name -> d -> Int
countOcc name = foldFreeVars var total
  where
    total = foldl' (+) 0
    var name' bv = if name /= name' || name' `S.member` bv then 0 else 1

-- variable capture avoiding substitution
--
-- 'subst' maps variable names to their replacements; 'bv' is the set of
-- names that binders encountered on the way down must not reuse.
substAvoiding :: Data d => M.Map Name Exp -> S.Set Name -> d -> d
substAvoiding subst bv = base `extT` exp `extT` alt `extT` decl `extT` match
  where
    -- default case: substitute in all immediate children
    base :: Data d => d -> d
    base = gmapT (substAvoiding subst bv)

    exp e@(Var (UnQual name)) =
        fromMaybe e (M.lookup name subst)
    exp (Lambda ps body) =
        let (subst', bv', ps') = renameBinds subst bv ps
        in  Lambda ps' (substAvoiding subst' bv' body)
    exp (Let bs body) =
        let (subst', bv', bs') = renameBinds subst bv bs
        in  Let (substAvoiding subst' bv' bs') (substAvoiding subst' bv' body)
    exp d = base d

    alt (Alt pat rhs bs) =
        let (subst1, bv1, pat') = renameBinds subst bv pat
            (subst', bv', bs')  = renameBinds subst1 bv1 bs
        in  Alt pat' (substAvoiding subst' bv' rhs) (substAvoiding subst' bv' bs')

    -- the pattern itself is left untouched here; only the where-bindings
    -- are renamed before descending
    decl (PatBind pat rhs bs) =
        let (subst', bv', bs') = renameBinds subst bv bs
        in  PatBind pat (substAvoiding subst' bv' rhs) (substAvoiding subst' bv' bs')
    decl d = base d

    match (Match name ps rhs bs) =
        let (subst1, bv1, ps') = renameBinds subst bv ps
            (subst', bv', bs') = renameBinds subst1 bv1 bs
        in  Match name ps' (substAvoiding subst' bv' rhs) (substAvoiding subst' bv' bs')
    -- infix definitions (x `f` y = ...) used to fall through to a
    -- pattern-match failure; treat them like prefix ones
    match (InfixMatch p name ps rhs bs) =
        let (subst1, bv1, p')  = renameBinds subst bv p
            (subst2, bv2, ps') = renameBinds subst1 bv1 ps
            (subst', bv', bs') = renameBinds subst2 bv2 bs
        in  InfixMatch p' name ps' (substAvoiding subst' bv' rhs) (substAvoiding subst' bv' bs')

-- rename local binders (but not the nested expressions)
--
-- Returns the updated substitution, the updated set of bound names and the
-- term with its binders renamed.  The third state component remembers the
-- renamings chosen so far during this call, so that a name seen again gets
-- the same replacement.
renameBinds :: Data d => M.Map Name Exp -> S.Set Name -> d -> (M.Map Name Exp, S.Set Name, d)
renameBinds subst bv d = (subst', bv', d')
  where
    (d', (subst', bv', _)) = runState (go d) (subst, bv, M.empty)

    go, base :: Data d => d -> State (M.Map Name Exp, S.Set Name, M.Map Name Name) d
    go = base `extM` pat `extM` match `extM` decl `extM` exp
    base d = gmapM go d

    pat (PVar name) = PVar `fmap` rename name
    pat d = base d

    -- only the function name is a binder at this level
    match (Match name ps rhs bs) = do
        name' <- rename name
        return $ Match name' ps rhs bs
    match (InfixMatch p name ps rhs bs) = do
        name' <- rename name
        return $ InfixMatch p name' ps rhs bs

    decl (PatBind p rhs bs) = do
        p' <- go p
        return $ PatBind p' rhs bs
    decl d = base d

    -- expressions are left alone
    exp (e :: Exp) = return e

    rename :: Name -> State (M.Map Name Exp, S.Set Name, M.Map Name Name) Name
    rename name = do
        (subst, bv, ass) <- get
        case (name `M.lookup` ass, name `S.member` bv) of
            -- already renamed during this call: reuse the choice
            (Just name', _) -> do
                return name'
            -- no clash: keep the name, but it now shadows any substitution
            (_, False) -> do
                put (M.delete name subst, S.insert name bv, ass)
                return name
            -- clash: pick a fresh name and substitute it for the old one
            _ -> do
                let name' = freshNameAvoiding name bv
                put (M.insert name (Var (UnQual name')) subst,
                     S.insert name' bv, M.insert name name' ass)
                return name'

-- generate fresh names
--
-- Strips any trailing suffix characters (digits for identifiers, '?' and
-- '#' for symbols) from the name and appends the first suffix, in order of
-- increasing length, for which the result is not in 'forbidden'.
freshNameAvoiding :: Name -> S.Set Name -> Name
freshNameAvoiding name forbidden = con (pre ++ suf)
  where
    (con, nm, cs) = case name of
        Ident n  -> (Ident,  n, "0123456789")
        Symbol n -> (Symbol, n, "?#")
    pre = dropWhileEnd (`elem` cs) nm
    sufs = [1..] >>= flip replicateM cs
    -- 'sufs' is infinite and 'forbidden' is finite, so a suffix always exists
    suf = case dropWhile (\s -> con (pre ++ s) `S.member` forbidden) sufs of
        (s : _) -> s
        []      -> error "Pointful.freshNameAvoiding: impossible, ran out of suffixes"

---- Optimization (removing explicit lambdas) and restoration of infix ops ----

-- move lambda patterns into LHS
optimizeD :: Decl -> Decl
optimizeD (PatBind (PVar fname) (UnGuardedRhs (Lambda pats rhs)) Nothing) =
    let (subst, bv, pats') = renameBinds M.empty (S.singleton fname) pats
        rhs' = substAvoiding subst bv rhs
    in  FunBind [Match fname pats' (UnGuardedRhs rhs') Nothing]
---- combine function binding and lambda
optimizeD (FunBind [Match fname pats1 (UnGuardedRhs (Lambda pats2 rhs)) Nothing]) =
    let (subst, bv, pats2') = renameBinds M.empty (varsBoundHere pats1) pats2
        rhs' = substAvoiding subst bv rhs
    in  FunBind [Match fname (pats1 ++ pats2') (UnGuardedRhs rhs') Nothing]
optimizeD x = x

-- remove parens
optimizeRhs :: Rhs -> Rhs
optimizeRhs (UnGuardedRhs (Paren x)) = UnGuardedRhs x
optimizeRhs x = x

-- one step of local expression simplification; the first matching rule wins
optimizeE :: Exp -> Exp
-- apply ((\x z -> ...x...) y) yielding (\z -> ...y...) if there is only one x or y is simple
optimizeE (App (Lambda (PVar ident : pats) body) arg) | single || simple arg =
    let (subst, bv, pats') = renameBinds (M.singleton ident arg) (freeVars arg) pats
    in  Paren (Lambda pats' (substAvoiding subst bv body))
  where
    single = countOcc ident body <= 1
    simple e = case e of
        Var _    -> True
        Lit _    -> True
        Paren e' -> simple e'
        _        -> False
-- apply ((\_ z -> ...) y) yielding (\z -> ...)
optimizeE (App (Lambda (PWildCard : pats) body) _) =
    Paren (Lambda pats body)
-- remove 0-arg lambdas resulting from application rules
optimizeE (Lambda [] b) =
    b
-- replace (\x -> \y -> z) with (\x y -> z)
optimizeE (Lambda p1 (Lambda p2 body)) =
    let (subst, bv, p2') = renameBinds M.empty (varsBoundHere p1) p2
        body' = substAvoiding subst bv body
    in  Lambda (p1 ++ p2') body'
-- remove double parens
optimizeE (Paren (Paren x)) =
    Paren x
-- remove parens around applied lambdas (the pretty printer restores them)
optimizeE (App (Paren (x@Lambda{})) y) =
    App x y
-- remove lambda body parens
optimizeE (Lambda p (Paren x)) =
    Lambda p x
-- remove var, lit parens
optimizeE (Paren x@(Var _)) =
    x
optimizeE (Paren x@(Lit _)) =
    x
-- remove infix+lambda parens
optimizeE (InfixApp a o (Paren l@(Lambda _ _))) =
    InfixApp a o l
-- remove infix+app parens
optimizeE (InfixApp (Paren a@App{}) o l) =
    InfixApp a o l
optimizeE (InfixApp a o (Paren l@App{})) =
    InfixApp a o l
-- remove left-assoc application parens
optimizeE (App (Paren (App a b)) c) =
    App (App a b) c
-- restore infix
optimizeE (App (App (Var name'@(UnQual (Symbol _))) l) r) =
    (InfixApp l (QVarOp name') r)
-- eta reduce
optimizeE (Lambda ps@(_:_) (App e (Var (UnQual v))))
    | free && last ps == PVar v = Lambda (init ps) e
  where
    free = countOcc v e == 0
-- fail
optimizeE x = x

---- Decombinatorization ----

-- one step of rewriting sections, infix applications and reader-monad
-- binds into explicit lambdas / prefix applications
uncomb' :: Exp -> Exp
uncomb' (Paren (Paren e)) = Paren e
-- eliminate sections
uncomb' (RightSection op' arg) =
    let a = freshNameAvoiding (Ident "a") (freeVars arg)
    in  (Paren (Lambda [PVar a] (InfixApp (Var (UnQual a)) op' arg)))
uncomb' (LeftSection arg op') =
    let a = freshNameAvoiding (Ident "a") (freeVars arg)
    in  (Paren (Lambda [PVar a] (InfixApp arg op' (Var (UnQual a)))))
-- infix to prefix for canonicality
uncomb' (InfixApp lf (QVarOp name') rf) =
    (Paren (App (App (Var name') (Paren lf)) (Paren rf)))

-- Expand (>>=) when it is obviously the reader monad:

-- rewrite: (>>=) (\x -> e)
-- to:      (\ a b -> a ((\ x -> e) b) b)
uncomb' (App (Var (UnQual (Symbol ">>="))) (Paren lam@Lambda{})) =
    let a = freshNameAvoiding (Ident "a") (freeVars lam)
        b = freshNameAvoiding (Ident "b") (freeVars lam)
    in  (Paren (Lambda [PVar a, PVar b]
            (App (App (Var (UnQual a)) (Paren (App lam (Var (UnQual b))))) (Var (UnQual b)))))
-- rewrite: ((>>=) e1) (\x y -> e2)
-- to:      (\a -> (\x y -> e2) (e1 a) a)
uncomb' (App (App (Var (UnQual (Symbol ">>="))) e1) (Paren lam@(Lambda (_:_:_) _))) =
    let a = freshNameAvoiding (Ident "a") (freeVars [e1,lam])
    in  (Paren (Lambda [PVar a]
            (App (App lam (App e1 (Var (UnQual a)))) (Var (UnQual a)))))

-- fail
uncomb' expr = expr

---- Simple combinator definitions ---

-- map from combinator names to their (parenthesized) definitions, obtained
-- by parsing 'combinatorModule'
combinators :: M.Map Name Exp
combinators = M.fromList $ map declToTuple defs
  where
    defs = case parseModule combinatorModule of
        ParseOk (Hs.Module _ _ _ d) -> d
        f@(ParseFailed _ _) -> error ("Combinator loading: " ++ show f)
    declToTuple (PatBind (PVar fname) (UnGuardedRhs body) Nothing)
        = (fname, Paren body)
    declToTuple _ = error "Pointful Plugin error: can't convert declaration to tuple"

-- source text of the combinator definitions, one simple binding per line
combinatorModule :: String
combinatorModule = unlines
    [ "(.) = \\f g x -> f (g x) "
    , "($) = \\f x -> f x "
    , "flip = \\f x y -> f y x "
    , "const = \\x _ -> x "
    , "id = \\x -> x "
    , "(=<<) = flip (>>=) "
    , "liftM2 = \\f m1 m2 -> m1 >>= \\x1 -> m2 >>= \\x2 -> return (f x1 x2) "
    , "join = (>>= id) "
    , "ap = liftM2 id "
    , "(>=>) = flip (<=<) "
    , "(<=<) = \\f g x -> f >>= g x "
    , " "
    , "-- ASSUMED reader monad "
    , "-- (>>=) = (\\f k r -> k (f r) r) "
    , "-- return = const "
    , ""
    ]

---- Top level ----

-- replace combinator names by their definitions; the free variables of the
-- definitions are passed as bound names so that binders clashing with them
-- get renamed
unfoldCombinators :: (Data a) => a -> a
unfoldCombinators = substAvoiding combinators (freeVars combinators)

-- apply uncomb' everywhere, bottom-up, once
uncombOnce :: (Data a) => a -> a
uncombOnce x = everywhere (mkT uncomb') x

-- apply uncombOnce until nothing changes
uncomb :: (Eq a, Data a) => a -> a
uncomb = stabilize uncombOnce

-- apply the declaration, right-hand-side and expression optimizations
-- everywhere, bottom-up, once
optimizeOnce :: (Data a) => a -> a
optimizeOnce x = everywhere (mkT optimizeD `extT` optimizeRhs `extT` optimizeE) x

-- apply optimizeOnce until nothing changes
optimize :: (Eq a, Data a) => a -> a
optimize = stabilize optimizeOnce

-- the plugin entry point: first unfold combinators to a fixed point, then
-- optimize the result to a fixed point
pointful :: String -> String
pointful = withParsed (stabilize (optimize . uncomb) . stabilize (unfoldCombinators . uncomb))

-- TODO: merge this into a proper test suite once one exists
-- test s = case parseModule s of
--   f@(ParseFailed _ _) -> fail (show f)
--   ParseOk (Hs.Module _ _ _ _ _ _ defs) ->
--     flip mapM_ defs $ \def -> do
--       putStrLn . prettyPrintInLine $ def
--       putStrLn . prettyPrintInLine . uncomb $ def
--       putStrLn . prettyPrintInLine . optimize . uncomb $ def
--       putStrLn . prettyPrintInLine . stabilize (optimize . uncomb) $ def
--       putStrLn ""
--
-- main = test "f = tail . head; g = head . tail; h = tail + tail; three = g . h . i; dontSub = (\\x -> x + x) 1; ofHead f = f . head; fm = flip mapM_ xs (\\x -> g x); po = (+1); op = (1+); g = (. f); stabilize = fix (ap . flip (ap . (flip =<< (if' .) . (==))) =<<)"
--