{-# LANGUAGE ConstraintKinds #-} {-# LANGUAGE DeriveGeneric #-} {-# LANGUAGE FlexibleContexts #-} {-# LANGUAGE FlexibleInstances #-} {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE TypeFamilies #-} {-# LANGUAGE UndecidableInstances #-} {-# OPTIONS_GHC -fno-warn-name-shadowing #-} module ArrSyn ( translate ) where import ArrCode import NewCode import Utils import Control.Monad.Trans.State import Control.Monad.Trans.Writer import Data.Default import Data.List (mapAccumL) import Data.Map (Map) import qualified Data.Map as Map import Data.Set (Set) import qualified Data.Set as Set import Debug.Hoed.Pure import Language.Haskell.Exts (Alt (..), Binds (..), Decl (..), Exp (), GuardedRhs (..), Match (..), Name, Pat (..), Rhs (..), Stmt (..), ann) import qualified Language.Haskell.Exts as H -- ----------------------------------------------------------------------------- -- Translation to Haskell -- This is a 2-phase process: -- - transCmd' generates an abstract arrow combinator language represented -- by the Arrow type, and -- - toHaskell turns that into Haskell. translate :: Pat S -> Exp S -> Exp S translate p c = H.Paren (ann c) $ toHaskell (transCmd s p' c) where (s, p') = startPattern p startPattern :: Pat S -> (TransState, Pat S) startPattern = observe "startPattern" startPattern' startPattern' :: Pat S -> (TransState, Pat S) startPattern' p = (TransState { locals = definedVars p, cmdVars = Map.empty }, p) -- The pattern argument is often pseudo-recursively defined in terms of -- the context part of the result of these functions. (It's not real -- recursion, because that part is independent of the pattern.) transCmd :: TransState -> Pat S -> Exp S -> Arrow transCmd = observe "transCmd" transCmd' transCmd' :: TransState -> Pat S -> Exp S -> Arrow transCmd' s p (H.LeftArrApp _ f e) | Set.null (freeVars f `Set.intersection` locals s) = arr 0 (input s) p e >>> arrowExp f | otherwise = arr 0 (input s) p (pair f e) >>> app transCmd' s p (H.LeftArrHighApp l f e) = transCmd s p (H.LeftArrApp l f e) transCmd' _ _ H.RightArrApp{} = error "assumption failed: Right arrow applications should have been desugared to left arrow applications before this point" transCmd' _ _ H.RightArrHighApp{} = error "assumption failed: Right arrow applications should have been desugared to left arrow applications before this point" transCmd' s p (H.InfixApp _ c1 op c2) = infixOp (transCmd s p c1) op (transCmd s p c2) transCmd' s p (H.Let _ decls c) = arrLet (anonArgs a) (input s) p decls' e >>> a where (s', decls') = addVars' s decls (e, a) = transTrimCmd s' c transCmd' s p (H.If l e c1 c2) | Set.null (freeVars e `Set.intersection` locals s) = ifte e (transCmd s p c1) (transCmd s p c2) | otherwise = arr 0 (input s) p (H.If l e (left e1) (right e2)) >>> (a1 ||| a2) where (e1, a1) = transTrimCmd s c1 (e2, a2) = transTrimCmd s c2 transCmd' s p (H.Case l e as) | Set.null (freeVars e `Set.intersection` locals s) = Arrow { context = ctx, anonArgs = 0, code = H.Case (Loc l) (Loc <$> e) alts } | otherwise = arr 0 (input s) p (H.Case l e as') >>> foldr1 (|||) (reverse cases) where (alts, ctx) = runWriter $ flip traverseAlts (Loc <$$> as) $ \exp -> do let arrow = transCmd s p (getLoc <$> exp) tell $ context arrow return $ code arrow (as', (ncases, cases)) = runState (mapM (transAlt s) as) (0, []) transAlt = observeSt "transAlt" transAlt' transAlt' s (Alt loc p gas decls) = do let (s', p') = addVars' s p (s'', decls') = addVars' s' decls gas' <- transGuardedRhss s'' gas return (H.Alt loc p' gas' decls') transGuardedRhss = observeSt "transGuardedRhss" transGuardedRhss' transGuardedRhss' s (UnGuardedRhs l c) = do body <- newAlt s c return (H.UnGuardedRhs l body) transGuardedRhss' s (GuardedRhss l gas) = do gas' <- mapM (transGuardedRhs s) gas return (H.GuardedRhss l gas') transGuardedRhs = observeSt "transGuardedRhs" transGuardedRhs' transGuardedRhs' s (GuardedRhs loc e c) = do body <- newAlt s c return (H.GuardedRhs loc e body) newAlt = observeSt "newAlt" newAlt' newAlt' s c = do let (e, a) = transTrimCmd s c (n, as) <- get put (n + 1, a : as) return (label n e) label = observe "label" label' label' n e = times n right (if n < ncases - 1 then left e else e) transCmd' s p (H.Paren _ c) = transCmd s p c transCmd' s p (H.Do _ ss) = transDo s p (init ss) (let Qualifier _ e = last ss in e) transCmd' s p (H.App _ c arg) = anon (-1) $ arr (anonArgs a) (input s) p (pair e arg) >>> a where (e, a) = transTrimCmd s c transCmd' s p (H.Lambda _ ps c) = anon (length ps) $ bind (definedVars ps) $ transCmd s' (foldl pairP p ps') c where (s', ps') = addVars' s ps transCmd' _ _ x = error $ "Invalid parse: " ++ showSrcLoc (H.fromSrcInfo $ getSrcSpanInfo $ ann x) where showSrcLoc :: H.SrcLoc -> String showSrcLoc (H.SrcLoc file line col) = file ++ ":" ++ show line ++ ":" ++ show col -- transCmd' s p (CmdVar n) = -- arr (anonArgs a) (input s) p e >>> arrowExp (H.Var () (H.UnQual () n)) -- where Just a = Map.lookup n (cmdVars s) -- e = expTuple (context a) -- Like TransCmd, but use the minimal input pattern. The first component -- of the result is the matching expression to build this input. -- That is, the result is (e, proc p' -> c) with the minimal p' such that -- proc p -> c = arr (first^n (p -> e)) >>> (proc p' -> c) -- where n is the number of anonymous arguments taken by c. transTrimCmd :: TransState -> Exp S -> (Exp S, Arrow) transTrimCmd = observe "transTrimCmd" transTrimCmd' transTrimCmd' :: TransState -> Exp S -> (Exp S, Arrow) transTrimCmd' s c = (expTuple (context a), a) where a = transCmd s (patternTuple (context a)) c transDo :: TransState -> Pat S -> [Stmt S] -> Exp S -> Arrow transDo = observe "transDo" transDo' transDo' :: TransState -> Pat S -> [Stmt S] -> Exp S -> Arrow transDo' s p [] c = transCmd s p c transDo' s p (Qualifier l exp : ss) c = transDo' s p (Generator l (PWildCard l) exp : ss) c transDo' s p (Generator _ pg cg:ss) c = if isEmptyTuple u then transCmd s p cg >>> transDo s' pg ss c else arr 0 (input s) p (pair eg (expTuple u)) >>> first ag u >>> a where (s', pg') = addVars' s pg a = observe "a" $ bind (definedVars pg) (transDo s' (pairP pg' (patternTuple u)) ss c) u = observe "u" $ context a (eg, ag) = transTrimCmd s cg transDo' s p (LetStmt l decls : ss) c = transCmd s p (H.Let l decls (H.Do l (ss ++ [Qualifier l c]))) transDo' s p (RecStmt l rss:ss) c = bind defined (loop (transDo s' (pairP p (irrPat (patternTuple feedback))) rss' (returnCmd (pair output (expTuple feedback)))) >>> a) where defined = foldMap definedVars rss (s', rss') = addVars' s rss (output, a) = transTrimCmd s' (H.Do l (ss ++ [Qualifier l c])) feedback = context (transDo s' p rss' (returnCmd (foldr (pair . H.Var l . H.UnQual l) output (const def <$$> Set.toList defined)))) `intersectTuple` defined data TransState = TransState { locals :: Set (Name ()), -- vars in scope defined in this proc cmdVars :: Map (Name ()) Arrow } deriving (Eq, Generic, Show) instance Observable TransState input :: TransState -> Tuple input s = Tuple (locals s) addVars' :: (Observable a, AddVars a) => TransState -> a -> (TransState, a) addVars' = observe "addVars" addVars class AddVars a where addVars :: TransState -> a -> (TransState, a) instance AddVars a => AddVars [a] where addVars = mapAccumL addVars instance AddVars a => AddVars (Maybe a) where addVars = mapAccumL addVars instance AddVars (Pat S) where addVars s p = (s {locals = locals s `Set.union` definedVars p}, p) instance AddVars (Decl S) where addVars s d@(FunBind l (Match _ n _ _ _:_)) = (s', d) where (s', _) = addVars s (PVar l n) addVars s (PatBind loc p rhs decls) = (s', PatBind loc p' rhs decls) where (s', p') = addVars s p addVars s d = (s, d) instance AddVars (Stmt S) where addVars s it@Qualifier{} = (s, it) addVars s (Generator loc p c) = (s', Generator loc p' c) where (s', p') = addVars s p addVars s (LetStmt l decls) = (s', LetStmt l decls') where (s', decls') = addVars s decls addVars s (RecStmt l stmts) = (s', RecStmt l stmts') where (s', stmts') = addVars s stmts instance AddVars (Binds S) where addVars s (BDecls l decls) = BDecls l <$> addVars s decls addVars s it@IPBinds{} = (s, it)