module ArrSyn
( translate
) where
import ArrCode
import NewCode
import Utils
import Control.Monad.Trans.State
import Control.Monad.Trans.Writer
import Data.List (mapAccumL)
import Data.Map (Map)
import qualified Data.Map as Map
import Data.Set (Set)
import qualified Data.Set as Set
import Debug.Hoed.Pure
import Language.Haskell.Exts (Alt (..), Binds (..), Decl (..),
Exp (), GuardedRhs (..),
Match (..), Name, Pat (..),
Rhs (..), Stmt (..), ann)
import qualified Language.Haskell.Exts as H
-- | Translate a command @cmd@ whose input is matched by the pattern @pat@
-- into a plain Haskell expression.
--
-- The initial translation state is derived from the pattern via
-- 'startPattern'; the resulting arrow is rendered with 'toHaskell' and
-- wrapped in parentheses carrying the annotation of the original command.
translate :: Pat S -> Exp S -> Exp S
translate pat cmd =
  let (initial, pat') = startPattern pat
      arrow = transCmd initial pat' cmd
  in H.Paren (ann cmd) (toHaskell arrow)
-- | Build the initial translation state for an input pattern.
--
-- This is 'startPattern'' wrapped in 'observe' under the label
-- @"startPattern"@.
startPattern :: Pat S -> (TransState, Pat S)
startPattern = observe "startPattern" startPattern'

-- | Worker for 'startPattern': the variables defined by the pattern become
-- the initial set of locals, no command variables are known yet, and the
-- pattern itself is handed back unchanged.
startPattern' :: Pat S -> (TransState, Pat S)
startPattern' pat = (initial, pat)
  where
    initial =
      TransState
        { locals = definedVars pat
        , cmdVars = Map.empty
        }
-- | Translate a command into an 'Arrow'.
--
-- Arguments are the translation state (which carries the set of local
-- variables), the pattern matching the arrow's input, and the command
-- expression. This is 'transCmd'' wrapped in 'observe' under the label
-- @"transCmd"@.
transCmd :: TransState -> Pat S -> Exp S -> Arrow
transCmd = observe "transCmd" transCmd'

-- | Worker for 'transCmd', with one equation per supported command form.
-- Expression forms not listed here are rejected by the final equation,
-- which calls 'error' with the shown expression.
transCmd' :: TransState -> Pat S -> Exp S -> Arrow
-- Arrow application @f -< e@: if @f@ mentions no local variable it is used
-- directly as the arrow; otherwise @f@ is paired with @e@ and fed to 'app'.
transCmd' s p (H.LeftArrApp _ f e)
  | Set.null (freeVars f `Set.intersection` locals s) =
      arr 0 (input s) p e >>> arrowExp f
  | otherwise =
      arr 0 (input s) p (pair f e) >>> app
-- The higher-order application form is translated exactly like the
-- first-order one; the guard above decides whether 'app' is needed.
transCmd' s p (H.LeftArrHighApp l f e) = transCmd s p (H.LeftArrApp l f e)
transCmd' _ _ H.RightArrApp{} =
  error "assumption failed: Right arrow applications should have been desugared to left arrow applications before this point"
transCmd' _ _ H.RightArrHighApp{} =
  error "assumption failed: Right arrow applications should have been desugared to left arrow applications before this point"
-- Infix operator between two commands: both operands are translated with
-- the same state and input pattern and combined with 'infixOp'.
transCmd' s p (H.InfixApp _ c1 op c2) =
  infixOp (transCmd s p c1) op (transCmd s p c2)
-- @let@: the declarations extend the state used for the body; the body is
-- translated by 'transTrimCmd' and fed from an 'arrLet' stage.
transCmd' s p (H.Let _ decls c) =
  arrLet (anonArgs a) (input s) p decls' e >>> a
  where
    (s', decls') = addVars' s decls
    (e, a) = transTrimCmd s' c
-- @if@: a condition free of local variables is handled by 'ifte';
-- otherwise the choice is encoded with 'left' / 'right' and dispatched
-- with '|||'.
transCmd' s p (H.If l e c1 c2)
  | Set.null (freeVars e `Set.intersection` locals s) =
      ifte e (transCmd s p c1) (transCmd s p c2)
  | otherwise =
      arr 0 (input s) p (H.If l e (left e1) (right e2)) >>> (a1 ||| a2)
  where
    (e1, a1) = transTrimCmd s c1
    (e2, a2) = transTrimCmd s c2
-- @case@: a scrutinee free of local variables keeps the @case@ as is and
-- translates each alternative body in place, collecting their contexts.
-- Otherwise every alternative body becomes a labelled injection and the
-- translated bodies are combined with '|||'.
transCmd' s p (H.Case l e as)
  | Set.null (freeVars e `Set.intersection` locals s) =
      Arrow
        { context = ctx
        , anonArgs = 0
        , code = H.Case (Loc l) (Loc <$> e) alts
        }
  | otherwise =
      -- 'cases' is accumulated in reverse order by 'newAlt', hence the
      -- 'reverse'. NOTE(review): 'foldr1' is partial on an empty list of
      -- alternatives.
      arr 0 (input s) p (H.Case l e as') >>> foldr1 (|||) (reverse cases)
  where
    (alts, ctx) = runWriter $ flip traverseAlts (Loc <$$> as) $ \exp -> do
      let arrow = transCmd s p (getLoc <$> exp)
      tell $ context arrow
      return $ code arrow
    -- The state is (number of bodies seen so far, their arrows, newest
    -- first); 'ncases' is the final count.
    (as', (ncases, cases)) = runState (mapM (transAlt s) as) (0, [])
    transAlt = observeSt "transAlt" transAlt'
    transAlt' s (Alt loc p gas decls) = do
      let (s', p') = addVars' s p
          (s'', decls') = addVars' s' decls
      gas' <- transGuardedRhss s'' gas
      return (H.Alt loc p' gas' decls')
    transGuardedRhss = observeSt "transGuardedRhss" transGuardedRhss'
    transGuardedRhss' s (UnGuardedRhs l c) = do
      body <- newAlt s c
      return (H.UnGuardedRhs l body)
    transGuardedRhss' s (GuardedRhss l gas) = do
      gas' <- mapM (transGuardedRhs s) gas
      return (H.GuardedRhss l gas')
    transGuardedRhs = observeSt "transGuardedRhs" transGuardedRhs'
    transGuardedRhs' s (GuardedRhs loc e c) = do
      body <- newAlt s c
      return (H.GuardedRhs loc e body)
    newAlt = observeSt "newAlt" newAlt'
    -- Translate one alternative body, record its arrow, and return the
    -- expression that injects its input into the position matching the
    -- body's index.
    newAlt' s c = do
      let (e, a) =
            transTrimCmd s c
      (n, as) <- get
      put (n + 1, a : as)
      return (label n e)
    label = observe "label" label'
    -- Body @n@ is wrapped in @n@ applications of 'right'. Every body but
    -- the last is additionally wrapped in 'left', matching the
    -- right-nested chain built by @foldr1 (|||)@ above. The comparison
    -- must be against @ncases - 1@ (the index of the last body).
    label' n e =
      times
        n
        right
        (if n < ncases - 1
           then left e
           else e)
transCmd' s p (H.Paren _ c) =
  transCmd s p c
-- @do@: the leading statements go to 'transDo'. The statement list is
-- assumed to be non-empty and to end in a 'Qualifier'; 'init', 'last' and
-- the pattern binding fail otherwise.
transCmd' s p (H.Do _ ss) =
  transDo s p (init ss) (let Qualifier _ e = last ss in e)
-- Command applied to an argument: the argument is paired with the
-- command's input, which uses up one of the command's anonymous
-- arguments, so the count is adjusted by -1 (the 'Lambda' equation below
-- adjusts it by @length ps@ in the other direction).
-- NOTE(review): 'anon' is defined outside this file; the sign assumes it
-- adds its first argument to the anonymous-argument count.
transCmd' s p (H.App _ c arg) =
  anon (-1) $
    arr (anonArgs a) (input s) p (pair e arg) >>> a
  where
    (e, a) = transTrimCmd s c
-- Command abstraction: the lambda's patterns are added to the state and
-- paired onto the input pattern, and their variables are bound in the
-- result.
transCmd' s p (H.Lambda _ ps c) =
  anon (length ps) $ bind (definedVars ps) $ transCmd s' (foldl pairP p ps') c
  where
    (s', ps') = addVars' s ps
transCmd' _ _ x = error $ "transCmd: " ++ show x
-- | Translate a command so that its input is exactly the tuple of variables
-- making up the context of the translated arrow.
--
-- Returns the tuple expression for that context together with the arrow.
-- This is 'transTrimCmd'' wrapped in 'observe' under the label
-- @"transTrimCmd"@.
transTrimCmd :: TransState -> Exp S -> (Exp S, Arrow)
transTrimCmd = observe "transTrimCmd" transTrimCmd'

-- | Worker for 'transTrimCmd'.
--
-- The arrow is translated against a pattern built from its own context, so
-- the two bindings below are mutually recursive and rely on lazy
-- evaluation.
transTrimCmd' :: TransState -> Exp S -> (Exp S, Arrow)
transTrimCmd' s c =
  let arrow = transCmd s (patternTuple (context arrow)) c
  in (expTuple (context arrow), arrow)
-- | Translate the statements of a @do@ command followed by its final
-- command.
--
-- Arguments are the translation state, the pattern matching the input, the
-- statements preceding the final command, and the final command itself.
-- This is 'transDo'' wrapped in 'observe' under the label @"transDo"@.
transDo :: TransState -> Pat S -> [Stmt S] -> Exp S -> Arrow
transDo = observe "transDo" transDo'

-- | Worker for 'transDo', dispatching on the first statement.
transDo' :: TransState -> Pat S -> [Stmt S] -> Exp S -> Arrow
-- No statements left: only the final command remains.
transDo' st inPat [] final =
  transCmd st inPat final
-- A bare command is a generator whose result is discarded.
transDo' st inPat (Qualifier l cmd : rest) final =
  transDo' st inPat (Generator l (PWildCard l) cmd : rest) final
-- Generator: when the continuation needs nothing from the current scope,
-- the generator command feeds the continuation directly; otherwise the
-- needed variables are carried alongside the generator via 'first'.
transDo' st inPat (Generator _ genPat genCmd : rest) final
  | isEmptyTuple carried =
      transCmd st inPat genCmd >>> transDo st' genPat rest final
  | otherwise =
      arr 0 (input st) inPat (pair genExp (expTuple carried))
        >>> first genArrow carried
        >>> continuation
  where
    (st', genPat') = addVars' st genPat
    -- 'continuation' and 'carried' are defined in terms of each other and
    -- rely on lazy evaluation.
    continuation =
      observe "a" $
        bind (definedVars genPat) $
          transDo st' (pairP genPat' (patternTuple carried)) rest final
    carried = observe "u" $ context continuation
    (genExp, genArrow) = transTrimCmd st genCmd
-- @let@ statement: re-expressed as a @let@ command around the remaining
-- @do@ block.
transDo' st inPat (LetStmt l decls : rest) final =
  transCmd st inPat (H.Let l decls (H.Do l (rest ++ [Qualifier l final])))
-- @rec@ block: the recursive statements are translated inside 'loop', with
-- the variables listed in 'feedback' fed back through an irrefutable
-- pattern; the remaining statements follow the loop.
transDo' st inPat (RecStmt l recStmts : rest) final =
  bind recVars (loop recBody >>> restArrow)
  where
    recVars = foldMap definedVars recStmts
    (st', recStmts') = addVars' st recStmts
    (output, restArrow) =
      transTrimCmd st' (H.Do l (rest ++ [Qualifier l final]))
    recBody =
      transDo
        st'
        (pairP inPat (irrPat (patternTuple feedback)))
        recStmts'
        (returnCmd (pair output (expTuple feedback)))
    -- The context of a trial translation returning all variables defined
    -- by the block plus the output, restricted to the variables the block
    -- defines.
    feedback = context trial `intersectTuple` recVars
    trial =
      transDo
        st'
        inPat
        recStmts'
        (returnCmd
           (foldr (pair . H.Var l . H.UnQual l) output (Set.toList recVars)))
-- | State threaded through the translation of a command.
data TransState = TransState
  { locals :: Set (Name S)
    -- ^ Names currently treated as local variables; the 'AddVars'
    -- instances extend this set with the variables a pattern defines.
  , cmdVars :: Map (Name S) Arrow
    -- ^ Arrows associated with names. Initialised to the empty map by
    -- 'startPattern''; not otherwise read or updated in the code visible
    -- here.
  } deriving (Eq, Generic, Show)

instance Observable TransState
-- | The tuple formed from the local variables of a translation state.
input :: TransState -> Tuple
input = Tuple . locals
-- | 'addVars' wrapped in 'observe' under the label @"addVars"@.
addVars'
  :: (Observable a, AddVars a, Eq l, Show l, l ~ LocType a)
  => TransState
  -> a
  -> (TransState, a)
addVars' = observe "addVars" addVars

-- | Syntax fragments that can extend a translation state.
class AddVars a where
  -- | Process a fragment against the given state, returning the resulting
  -- state together with the (possibly rebuilt) fragment.
  addVars :: TransState -> a -> (TransState, a)

-- | A list is processed left to right, threading the state through every
-- element.
instance AddVars a => AddVars [a] where
  addVars st items = mapAccumL addVars st items

-- | An optional fragment is processed when present; 'Nothing' leaves the
-- state untouched.
instance AddVars a => AddVars (Maybe a) where
  addVars st item = mapAccumL addVars st item
-- | A pattern adds every variable it defines to the locals; the pattern
-- itself is returned unchanged.
instance AddVars (Pat S) where
  addVars st pat = (extended, pat)
    where
      extended = st { locals = Set.union (locals st) (definedVars pat) }
-- | Declarations add the names they bind to the locals.
instance AddVars (Decl S) where
  -- A function binding adds the name of its first clause. Clauses written
  -- infix ('InfixMatch') bind a name just like prefix ones ('Match'), so
  -- both forms are handled; previously an infix-defined function fell
  -- through to the catch-all equation and was never added to the locals.
  addVars s d@(FunBind l (m : _)) =
    (s', d)
    where
      (s', _) = addVars s (PVar l (clauseName m))
      clauseName (Match _ n _ _ _) = n
      clauseName (InfixMatch _ _ n _ _ _) = n
  -- A pattern binding adds the variables of its pattern.
  addVars s (PatBind loc p rhs decls) =
    (s', PatBind loc p' rhs decls)
    where
      (s', p') = addVars s p
  -- Any other declaration (including a function binding with no clauses)
  -- leaves the state unchanged.
  addVars s d = (s, d)
-- | Statements add whatever their binding parts define: a generator adds
-- its pattern, @let@ and @rec@ add their contents, and a bare command adds
-- nothing.
instance AddVars (Stmt S) where
  addVars st stmt = case stmt of
    Qualifier{} -> (st, stmt)
    Generator loc pat cmd ->
      let (st', pat') = addVars st pat
      in (st', Generator loc pat' cmd)
    LetStmt l binds ->
      let (st', binds') = addVars st binds
      in (st', LetStmt l binds')
    RecStmt l body ->
      let (st', body') = addVars st body
      in (st', RecStmt l body')
-- | A group of ordinary declarations adds what its declarations add;
-- implicit-parameter bindings leave the state unchanged.
instance AddVars (Binds S) where
  addVars st binds = case binds of
    BDecls l decls -> fmap (BDecls l) (addVars st decls)
    IPBinds{} -> (st, binds)