{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE UndecidableInstances #-}
{-# OPTIONS_GHC -fno-warn-name-shadowing #-}
module ArrSyn
( translate
) where
import ArrCode
import NewCode
import Utils
import Control.Monad.Trans.State
import Control.Monad.Trans.Writer
import Data.Default
import Data.List (mapAccumL)
import Data.Map (Map)
import qualified Data.Map as Map
import Data.Set (Set)
import qualified Data.Set as Set
import Debug.Hoed.Pure
import Debug.Hoed.Pure.TH
import Language.Haskell.Exts (Alt (..), Binds (..), Decl (..),
Exp (), GuardedRhs (..),
Match (..), Name, Pat (..),
Rhs (..), Stmt (..), ann)
import qualified Language.Haskell.Exts as H
-- | State threaded through the translation of a command.
data TransState = TransState
  { -- | Names currently treated as locally bound; extended by 'addVars' and
    -- consulted to decide whether an expression mentions a local variable.
    locals :: Set (Name ())
  , -- | Map from names to arrows.  Within this module it is only ever
    -- initialised to the empty map ('startPattern').
    cmdVars :: Map (Name ()) Arrow
  }
  deriving (Eq, Generic, Show)
-- Empty instance body: every method comes from the class defaults.
-- NOTE(review): presumably these defaults are driven by the derived 'Generic'
-- instance of 'TransState' -- confirm against the Hoed documentation.
instance Observable TransState
obs [d|
-- Build the initial translation state for a top-level input pattern: the
-- variables defined by the pattern become the locals, and no command
-- variables are known yet.  The pattern itself is returned unchanged.
startPattern :: Pat S -> (TransState, Pat S)
startPattern p = (initial, p)
  where
    initial = TransState {locals = definedVars p, cmdVars = Map.empty}
-- Translate one arrow command into an 'Arrow'.
--
-- Arguments: the translation state (which names count as local), the pattern
-- describing the command's input, and the command syntax itself.
--
-- Calls 'error' for right-arrow applications (expected to be desugared
-- earlier), for a do command that does not end in a command statement, and
-- for any expression form that is not handled below.
transCmd :: TransState -> Pat S -> Exp S -> Arrow
-- f -< e: when f mentions no local variable it is used directly via
-- 'arrowExp'; otherwise f is paired with the input and applied with 'app'.
-- NOTE(review): both branches use p', which is trimmed against e alone, while
-- the second branch also embeds f -- confirm that 'trimPat' cannot drop
-- variables f refers to.
transCmd s p (H.LeftArrApp _ f e)
  | Set.null (freeVars f `Set.intersection` locals s) =
      arr 0 (input s) p' e >>> arrowExp f
  | otherwise =
      arr 0 (input s) p' (pair f e) >>> app
  where
    p' = trimPat e p
-- f -<< e is handled exactly like f -< e.
transCmd s p (H.LeftArrHighApp l f e) = transCmd s p (H.LeftArrApp l f e)
transCmd _ _ H.RightArrApp{} =
  error "assumption failed: Right arrow applications should have been desugared to left arrow applications before this point"
transCmd _ _ H.RightArrHighApp{} =
  error "assumption failed: Right arrow applications should have been desugared to left arrow applications before this point"
-- Infix operator between two commands: translate both sides with the same
-- state and input pattern and combine them with 'infixOp'.
transCmd s p (H.InfixApp _ c1 op c2) =
  infixOp (transCmd s p c1) op (transCmd s p c2)
-- let decls in c: the body is translated with the let-bound names added to
-- the locals.
transCmd s p (H.Let _ decls c) =
  arrLet (anonArgs a) (input s) (trimPat e p) decls' e >>> a
  where
    (s', decls') = addVars s decls
    (e, a) = transTrimCmd s' c
-- if e then c1 else c2: a condition free of locals goes through 'ifte';
-- otherwise the condition is evaluated inside 'arr' and the branches are
-- selected with 'left' / 'right' and (|||).
transCmd s p (H.If l e c1 c2)
  | Set.null (freeVars e `Set.intersection` locals s) =
      ifte e (transCmd s p c1) (transCmd s p c2)
  | otherwise =
      arr 0 (input s) p (H.If l e (left e1) (right e2)) >>> (a1 ||| a2)
  where
    (e1, a1) = transTrimCmd s c1
    (e2, a2) = transTrimCmd s c2
-- case e of alts: a scrutinee free of locals keeps the case expression and
-- translates each alternative body in place; otherwise every alternative body
-- becomes a tagged expression and the translated commands are joined by (|||).
transCmd s p (H.Case l e as)
  | Set.null (freeVars e `Set.intersection` locals s) =
      Arrow
        { context = ctx
        , anonArgs = 0
        , code = H.Case (Loc l) (Loc <$> e) alts
        }
  | otherwise =
      arr 0 (input s) p (H.Case l e as') >>> foldr1 (|||) (reverse cases)
  where
    -- First branch: translate every alternative body, collecting the
    -- contexts of the resulting arrows in the writer.
    (alts, ctx) = runWriter $ flip traverseAlts (Loc <$$> as) $ \exp -> do
      let arrow = transCmd s p (getLoc <$> exp)
      tell $ context arrow
      return $ code arrow
    -- Second branch: the state is (number of bodies seen so far, their
    -- translated arrows, most recent first).
    (as', (ncases, cases)) = runState (mapM (transAlt s) as) (0, [])
    transAlt = observeSt "transAlt" transAlt'
    transAlt' s (Alt loc p gas decls) = do
      let (s', p') = addVars s p
          (s'', decls') = addVars s' decls
      gas' <- transGuardedRhss s'' gas
      return (H.Alt loc p' gas' decls')
    transGuardedRhss = observeSt "transGuardedRhss" transGuardedRhss'
    transGuardedRhss' s (UnGuardedRhs l c) = do
      body <- newAlt s c
      return (H.UnGuardedRhs l body)
    transGuardedRhss' s (GuardedRhss l gas) = do
      gas' <- mapM (transGuardedRhs s) gas
      return (H.GuardedRhss l gas')
    transGuardedRhs = observeSt "transGuardedRhs" transGuardedRhs'
    transGuardedRhs' s (GuardedRhs loc e c) = do
      body <- newAlt s c
      return (H.GuardedRhs loc e body)
    -- Record the translated command for one body and return the tagged
    -- expression that replaces it.
    newAlt = observeSt "newAlt" newAlt'
    newAlt' s c = do
      let (e, a) = transTrimCmd s c
      (n, as) <- get
      put (n + 1, a : as)
      return (label n e)
    -- Body number n is wrapped in n applications of 'right'; every body but
    -- the last is additionally wrapped in 'left'.
    label = observe "label" label'
    label' n e =
      times n right (if n < ncases - 1 then left e else e)
transCmd s p (H.Paren _ c) =
  transCmd s p c
-- do { stmts; c }: split off the final command statement explicitly instead
-- of relying on partial 'init' / 'last', so a malformed statement list fails
-- with a descriptive message.
transCmd s p (H.Do _ ss) =
  case reverse ss of
    Qualifier _ c : revInit -> transDo s p (reverse revInit) c
    _ ->
      error "assumption failed: the last statement of a do command should be a command"
-- c arg: command application.
transCmd s p (H.App _ c arg) =
  anon (-1) $
    arr (anonArgs a) (input s) p (pair e arg) >>> a
  where
    (e, a) = transTrimCmd s c
-- \ps -> c: the lambda-bound patterns are added to the locals and paired onto
-- the input pattern.
transCmd s p (H.Lambda _ ps c) =
  anon (length ps) $ bind (definedVars ps) $ transCmd s' (foldl pairP p ps') c
  where
    (s', ps') = addVars s ps
-- Anything else is not a command form handled here.
transCmd _ _ x = error $ "Invalid parse: " ++ showSrcLoc (H.fromSrcInfo $ getSrcSpanInfo $ ann x)
  where
    showSrcLoc :: H.SrcLoc -> String
    showSrcLoc (H.SrcLoc file line col) = file ++ ":" ++ show line ++ ":" ++ show col
-- Translate a command whose input is exactly the context of the resulting
-- arrow.  Returns the expression tuple for that context together with the
-- arrow.
--
-- The definition is circular on purpose: the pattern handed to 'transCmd' is
-- built from the context of the arrow 'transCmd' returns, so the bindings
-- below must stay lazy.
transTrimCmd :: TransState -> Exp S -> (Exp S, Arrow)
transTrimCmd s c =
  let a = transCmd s (patternTuple ctx) c
      ctx = context a
   in (expTuple ctx, a)
-- Translate the statements of a do command followed by its final command.
--
-- Arguments: translation state, input pattern, the statements preceding the
-- final command, and the final command.
transDo :: TransState -> Pat S -> [Stmt S] -> Exp S -> Arrow
-- No statements left: only the final command remains.
transDo s p [] c = transCmd s p c
-- A bare command statement is a generator binding a wildcard.
transDo s p (Qualifier l exp : ss) c =
  transDo s p (Generator l (PWildCard l) exp : ss) c
-- pg <- cg: 'a' (the rest of the block) and 'u' (its context) are defined in
-- terms of each other and rely on lazy evaluation.
transDo s p (Generator _ pg cg : ss) c
  | isEmptyTuple u =
      transCmd s p cg >>> transDo s' pg ss c
  | otherwise =
      arr 0 (input s) p (pair eg (expTuple u)) >>> first ag u >>> a
  where
    (s', pg') = addVars s pg
    (eg, ag) = transTrimCmd s cg
    u = observe "u" $ context a
    a =
      observe "a" $
        bind (definedVars pg) (transDo s' (pairP pg' (patternTuple u)) ss c)
-- let decls: rewritten to a let command wrapping the rest of the block.
transDo s p (LetStmt l decls : ss) c =
  transCmd s p (H.Let l decls (H.Do l (ss ++ [Qualifier l c])))
-- rec { rss }: the recursive statements run inside 'loop', feeding back the
-- variables they define that are actually needed ('feedback').
transDo s p (RecStmt l rss : ss) c =
  bind defined (loop recBody >>> a)
  where
    defined = foldMap definedVars rss
    (s', rss') = addVars s rss
    (output, a) = transTrimCmd s' (H.Do l (ss ++ [Qualifier l c]))
    recBody =
      transDo
        s'
        (pairP p (irrPat (patternTuple feedback)))
        rss'
        (returnCmd (pair output (expTuple feedback)))
    -- Context of a trial translation that returns every defined variable,
    -- restricted to the variables the rec block defines.
    feedback = context trial `intersectTuple` defined
    trial =
      transDo
        s'
        p
        rss'
        (returnCmd
           (foldr (pair . H.Var l . H.UnQual l) output (const def <$$> Set.toList defined)))
|]
-- | Translate a command with the given input pattern into a parenthesised
-- Haskell expression.  The parentheses reuse the annotation of the command.
translate :: Pat S -> Exp S -> Exp S
translate p c =
  let (s, p') = startPattern p
      arrow = transCmd s p' c
   in H.Paren (ann c) (toHaskell arrow)
-- | The tuple made of the local variables recorded in the state.
input :: TransState -> Tuple
input = Tuple . locals
-- | Syntax fragments that can extend the set of local variables.
class AddVars a where
  -- | Thread the translation state through a fragment, returning the updated
  -- state together with the (possibly rebuilt) fragment.
  addVars :: TransState -> a -> (TransState, a)
-- | Lists are processed left to right, each element seeing the state produced
-- by the elements before it.
instance AddVars a => AddVars [a] where
  addVars s xs = mapAccumL addVars s xs
-- | An optional fragment: 'Nothing' leaves the state untouched, 'Just'
-- processes the contained fragment.
instance AddVars a => AddVars (Maybe a) where
  addVars s mx = mapAccumL addVars s mx
-- | A pattern adds the variables it defines to the locals; the pattern itself
-- is returned unchanged.
instance AddVars (Pat S) where
  addVars s p = (extended, p)
    where
      extended = s {locals = Set.union (locals s) (definedVars p)}
-- | Declarations: a function binding adds the function's name to the locals,
-- a pattern binding adds the variables of its pattern, and every other
-- declaration leaves the state unchanged.
instance AddVars (Decl S) where
  -- The name is taken from the first clause, whether it is written prefix
  -- ('Match') or infix ('InfixMatch'); previously infix clauses fell through
  -- to the catch-all and their name was never recorded as local.
  addVars s d@(FunBind l (m : _)) = (s', d)
    where
      (s', _) = addVars s (PVar l (matchName m))
      matchName (Match _ n _ _ _) = n
      matchName (InfixMatch _ _ n _ _ _) = n
  addVars s (PatBind loc p rhs decls) = (s', PatBind loc p' rhs decls)
    where
      (s', p') = addVars s p
  addVars s d = (s, d)
-- | Statements: generators add the variables of their pattern, let and rec
-- statements add whatever their contents add, and bare command statements
-- add nothing.
instance AddVars (Stmt S) where
  addVars s stmt = case stmt of
    Qualifier{} -> (s, stmt)
    Generator loc p c ->
      let (s', p') = addVars s p
       in (s', Generator loc p' c)
    LetStmt l decls ->
      let (s', decls') = addVars s decls
       in (s', LetStmt l decls')
    RecStmt l stmts ->
      let (s', stmts') = addVars s stmts
       in (s', RecStmt l stmts')
-- | Binding groups: ordinary declaration groups delegate to their
-- declarations; implicit-parameter bindings leave the state unchanged.
instance AddVars (Binds S) where
  addVars s binds = case binds of
    BDecls l decls -> fmap (BDecls l) (addVars s decls)
    IPBinds{} -> (s, binds)