-- File created: 2009-07-26 16:12:14
{-# LANGUAGE PatternGuards #-}
module Haschoo.Evaluator.Macros (mkMacro) where
import Control.Applicative ((<$>))
import Control.Arrow (first)
import Control.Monad (liftM2, when)
import Control.Monad.Error (throwError)
import Control.Monad.Loops (andM, allM, firstM, untilM)
import Control.Monad.State.Strict (State, runState, put, get, modify)
import Control.Monad.Trans (MonadIO, lift, liftIO)
import Control.Monad.Writer.Strict (WriterT, runWriterT, tell)
import Data.Array.IArray (elems)
import Data.IORef (IORef)
import Data.List (find)
import Data.Maybe (catMaybes, isNothing)
import Data.Monoid (Monoid(mempty, mappend))
import qualified Data.ListTrie.Patricia.Map.Enum as TM
import Data.ListTrie.Patricia.Map.Enum (TrieMap)
import Haschoo.Types ( Haschoo
, ScmValue
( ScmMacro, ScmIdentifier
, ScmList, ScmDottedList
, ScmVector)
, MacroCall(..)
, Context
, isAggregate, toScmVector)
import Haschoo.Utils ( fst3, initLast2Maybe
, ptrEq, eqWithM, ErrOr)
import Haschoo.Evaluator.Eval (maybeEval)
import Haschoo.Evaluator.Standard.Equivalence (scmEqual)
-- The Left case is for ellipses
newtype PatternMatches = PM (TrieMap Char (Either [ScmValue] ScmValue))
-- Right-biased union instead of the default left-biased: so we can overwrite
-- keys in a Writer with a second call to tell
instance Monoid PatternMatches where
mempty = PM mempty
mappend (PM a) (PM b) = PM (TM.union b a)
type Lits = [(String, Maybe ScmValue)]
mkMacro :: [IORef Context] -> String -> [(ScmValue, ScmValue, [String])]
-> Lits -> ScmValue
mkMacro ctx name pats ls = ScmMacro name f
where
f args = do
(matching,replacements) <-
runWriterT $ firstM (match ls args . fst3) pats
case matching of
Nothing -> throwError ("Invalid use of macro " ++ name)
Just (_,v,frees) ->
case replaceVars (map fst ls) frees replacements v of
Left s -> throwError s
Right replaced ->
return ( simplifyList replaced
, TM.fromList (zip frees (repeat ctx)))
-- Turning MCDotted to ScmDottedList here means that the list in the
-- ScmDottedList may actually be empty, which can't happen normally
match :: Lits -> MacroCall -> ScmValue -> WriterT PatternMatches Haschoo Bool
match lits (MCList xs) = match1 lits (ScmList xs)
match lits (MCDotted xs x) = match1 lits (ScmDottedList xs x)
-- Paraphrasing R5RS 4.3.2:
--
-- "... formally, an input form F matches a pattern P iff:" ...
match1 :: Lits -> ScmValue -> ScmValue -> WriterT PatternMatches Haschoo Bool
match1 lits arg (ScmIdentifier i) =
case find ((== i).fst) lits of
-- ... "P is a non-literal identifier" ...
Nothing -> do tell (PM $ TM.singleton i (Right arg))
return True
Just (_, binding) ->
case arg of
-- ... "P is a literal identifier and F is an identifier with
-- the same binding, or the two identifiers are equal and both
-- have no lexical binding" ...
x@(ScmIdentifier i2) -> do
xb <- lift $ maybeEval x
case binding of
Nothing -> return (isNothing xb && i == i2)
Just pb -> maybe (return False) (liftIO . ptrEq pb) xb
_ -> return False
match1 lits (ScmList args) (ScmList ps) =
case initLast2Maybe ps of
-- ... "P is of the form (P1 ... Pn Pm ) where is
-- the identifier ... and F is a list of at least n forms, the first n
-- of which match P1 through Pn and each remaining element matches Pm"
-- ...
Just (ps', p, ScmIdentifier "...") ->
let (xs, ys) = splitAt (length ps') args
in do m <- andM [ allMatch lits xs ps'
, allM (flip (match1 lits) p) ys]
when m (assignMatches ys p)
return m
-- ... "P is a list (P1 ... Pn) and F is a list of n forms that match
-- P1 through Pn" ...
_ -> allMatch lits args ps
where
assignMatches ys (ScmIdentifier i) = tell (PM $ TM.singleton i (Left ys))
-- E.g. pl = (x y z) and ys = [(1 2 3),(4 5 6)]
--
-- Bind x to [1,4], y to [2,5], z to [3,6]
assignMatches ys p =
case p of
ScmList _ -> assignList (\(ScmList l) -> l)
ScmVector _ -> assignList (\(ScmVector v) -> elems v)
ScmDottedList _ _ -> assignList (\(ScmDottedList l x) -> x:l)
_ -> return ()
where
assignList toList =
let -- Continuing the above example, this would give
-- [(x,1),(x,4),(y,2),(y,5),(z,3),(z,6)]
lpat = toList p
paired = concatMap (zip lpat . toList) ys
-- From then on it's quite trivial.
--
-- Just be sure to add an initial empty match for each identifier.
-- This ensures that patterns such as ((x) ...), when matched
-- against (), result in any x in the output being replaced with
-- nothing, instead of them not being seen as matches.
matched = [(i, []) | ScmIdentifier i <- lpat]
++ [(i, [x]) | (ScmIdentifier i, x) <- paired]
in tell . PM . TM.map Left $ TM.fromListWith (flip (++)) matched
-- ... "P is an improper list (P1 ... Pn . Pm)" ...
match1 lits args (ScmDottedList ps p) =
case args of
-- ... "and F is a list or improper list of n or more forms that match
-- P1 through Pn and whose nth cdr matches Pm" ...
ScmList as -> let (xs, ys) = splitAt (length ps) as
in andM [ allMatch lits xs ps
, match1 lits (ScmList ys) p]
ScmDottedList as a -> andM [allMatch lits as ps, match1 lits a p]
_ -> return False
-- ... "P is a vector of the form #(P1 ... Pn) and F is a vector of n forms
-- that match P1 through Pn" ...
--
-- ... "P is of the form #(P1 ... Pn Pm ) where is the
-- identifier ... and F is a vector of n or more forms, the first n of which
-- match P1 through Pn and each remaining element matches Pm" ...
--
-- Just forward to the list one, the rules are identical
match1 lits (ScmVector args) (ScmVector ps) =
match1 lits (ScmList $ elems args) (ScmList $ elems ps)
-- ... "P is a datum and F is equal to P in the sense of the equal? procedure".
match1 _ arg p = liftIO $ scmEqual arg p
allMatch :: Lits -> [ScmValue] -> [ScmValue]
-> WriterT PatternMatches Haschoo Bool
allMatch = eqWithM . match1
replaceVars :: [String] -> [String] -> PatternMatches
-> ScmValue -> ErrOr [ScmValue]
replaceVars _ _ (PM replacements) (ScmIdentifier i)
| Just r <- TM.lookup i replacements = case r of
Right r' -> return [r']
Left rs -> return rs
| i == "..." = return []
replaceVars lits frees rs (ScmList l) =
(:[]) . ScmList <$> replaceInAggregate lits frees rs l
replaceVars lits frees rs (ScmVector v) =
(:[]) . toScmVector <$> replaceInAggregate lits frees rs (elems v)
replaceVars lits frees rs (ScmDottedList l x) =
(:[]) <$> liftM2 ScmDottedList
(replaceInAggregate lits frees rs l)
(simplifyList <$> replaceVars lits frees rs x)
replaceVars _ _ _ v = return [v]
replaceInAggregate :: [String] -> [String] -> PatternMatches
-> [ScmValue] -> ErrOr [ScmValue]
replaceInAggregate lits frees rs (ag : ScmIdentifier "..." : vs)
| isAggregate ag =
liftM2 (++) (replaceInEllipticAggregate lits frees rs ag)
(replaceInAggregate lits frees rs vs)
replaceInAggregate lits frees rs (v:vs) =
liftM2 (++) (replaceVars lits frees rs v)
(replaceInAggregate lits frees rs vs)
replaceInAggregate _ _ _ [] = return []
-- "Pattern variables that occur in subpatterns followed by one or more
-- instances of the identifier ... are allowed only in subtemplates that are
-- followed by as many instances of ....
--
-- They [the subtemplates] are replaced in the output by all of the subforms
-- they [the subpatterns] match in the input, distributed as indicated [in the
-- template, by the positions of the pattern variables]."
--
-- With square-bracketed stuff added to make some sense of the last sentence
-- there.
--
-- E.g. pattern (_ (#(a b) ...)) and corresponding template ((a b) ...).
--
-- Pattern variables a and b are in a subpattern followed by one ellipsis. We
-- also have a subtemplate followed by as many ellipses, and a and b are found
-- nowhere else, so it's valid. Calling this with (_ (#(1 2) #(3 4))) we get
-- the subtemplate '(a b) ...' replaced by '(1 2) (3 4)'.
replaceInEllipticAggregate :: [String] -> [String] -> PatternMatches
-> ScmValue -> ErrOr [ScmValue]
replaceInEllipticAggregate lits frees pms val =
let (replaced, (replacementsLeft,_)) =
flip runState (TM.empty, pms) $
(go val `untilM`
do rsLeft <- fst <$> get
return (TM.null rsLeft || anyT (==0) rsLeft))
in if allT (==0) replacementsLeft
then return$ catMaybes replaced
else -- For example:
-- Pattern has (a ...) (b ...)
-- Template has ((a b) ...)
-- (a ...) was (1 2) and (b ...) was (3)
-- The first (1 3) is fine, but the next (2 ??) isn't
throwError "Inconsistent match counts for elliptic variables"
where
go :: ScmValue -> State (TrieMap Char Int, PatternMatches) (Maybe ScmValue)
go (ScmIdentifier i)
| i `elem` frees || i `elem` lits = return Nothing
| otherwise = do
(replacementsLeft, PM pm) <- get
let (v, pm') =
TM.updateLookup (Just . either (Left . drop 1) Right) i pm
case v of
Just (Right r') -> return (Just r')
Just (Left []) -> do modify (first $ TM.insert i 0)
return Nothing
Just (Left (r:rs)) -> do
let rsLeft' =
TM.alter' (Just . maybe (length rs) (subtract 1))
i replacementsLeft
put (rsLeft', PM pm')
return (Just r)
-- Blank match: e.g. pattern (a ...) matched ()
Nothing -> return Nothing
go (ScmList l) = fmap ScmList . sequence <$> mapM go l
go (ScmVector v) = fmap toScmVector . sequence <$> mapM go (elems v)
go (ScmDottedList l x) = do
l' <- sequence <$> mapM go l
x' <- go x
return (liftM2 ScmDottedList l' x')
go x = return (Just x)
anyT f = TM.foldr ((||).f) False
allT f = TM.foldr ((&&).f) True
simplifyList :: [ScmValue] -> ScmValue
simplifyList [v] = v
simplifyList vs = ScmList vs