module Rewrite where

import Term ( Term(Node, Number, StringLiteral), range, name, termRange )
import TermFocus ( TermFocus(TermFocus), SuperTerm )
import Program ( Program )
import SourceText ( ModuleRange )
import qualified Program
import qualified Module
import qualified TermFocus
import qualified Term
import qualified Rule

import qualified Control.Monad.Trans.Writer as MW
import qualified Control.Monad.Trans.RWS as MRWS
import qualified Control.Monad.Trans.Class as MT
import Control.Monad.Trans.RWS ( RWS, asks, tell, get, put )
import Control.Monad.Exception.Synchronous
          ( Exceptional(Exception,Success), ExceptionalT,
            mapExceptionalT, throwT, assertT )

import qualified Data.Map as Map
import qualified Data.Traversable as Trav
import Data.Map ( Map )
import Data.Monoid ( Monoid )
import Data.Maybe.HT ( toMaybe )
import Data.Tuple.HT ( mapSnd )
import Data.List ( intercalate )
import Data.Eq.HT ( equating )

-- import Debug.Trace ( trace )


type RTerm = Term.Term ModuleRange
type Identifier = Term.Identifier ModuleRange

-- | Log entries written by the evaluator.
data Message =
      Term { term :: TermFocus }
        -- ^ written by 'top' for every term it is asked to evaluate,
        --   together with the enclosing super terms
    | Source { source :: Source }
        -- ^ an event that refers to an identifier of the program text
    deriving Show

-- | Events that refer to identifiers of the program text.
data Source =
      Step { target :: Identifier }
        -- ^ a reduction step was performed at this operator or function
    | AttemptRule { rule :: Identifier }
        -- ^ the rule with this identifier is about to be tried
    | Rule { rule :: Identifier }
        -- ^ the rule with this identifier matched and is applied
    | Data { origin :: Identifier }
        -- ^ the entry of 'Program.constructors' for a constructor
        --   that occurs in the patterns of the applied rule
    deriving Show

-- | Read-only environment of the evaluator.
data Context =
    Context {
        maxReductions :: Count
            -- ^ limit enforced by 'checkMaxReductions'
      , program :: Program
            -- ^ source of the functions and constructors used by evaluation
      , superTerms :: [ SuperTerm ]
            -- ^ terms enclosing the one currently evaluated,
            --   innermost first; extended by 'localSuperTerm'
    }

-- | Counter threaded through the evaluation as state.
type Count = Int

-- | Evaluation monad.
-- Failures carry a source range and a message;
-- underneath, a 'Context' is read, 'Message's are logged
-- and a 'Count' is threaded as state.
type Evaluator =
    ExceptionalT (ModuleRange, String) ( RWS Context [ Message ] Count )

-- | Run an evaluation with the given reduction limit and program,
-- starting with counter 0 and no super terms.
-- The log is passed to the outer 'MW.WriterT'; it is kept
-- even if evaluation fails. The final counter is discarded.
runEval ::
    (Monad m) =>
    Count -> Program -> Evaluator a ->
    ExceptionalT (ModuleRange, String) ( MW.WriterT [ Message ] m ) a
runEval maxRed p =
    -- in transformers-0.3 you can write MW.writer instead of MW.WriterT . return
    mapExceptionalT
        (\evl ->
            MW.WriterT $ return $
            MRWS.evalRWS evl
                (Context {maxReductions = maxRed, program = p, superTerms = []})
                0)
{-
    mapExceptionalT
        (\evl ->
            MW.WriterT $ return $
            (\(a,s,w) -> trace (show s) (a,w)) $
            MRWS.runRWS evl (maxRed,p) 0)
-}

-- | Abort evaluation with a message located at the given range.
exception :: ModuleRange -> String -> Evaluator a
exception rng msg = throwT (rng, msg)

-- | force head of stream:
-- evaluate until we have Cons or Nil at root,
-- then evaluate first argument of Cons fully.
-- The tail of a Cons is left untouched.
-- Any other root is an error.
forceHead :: RTerm -> Evaluator RTerm
forceHead t = do
    t' <- top t
    case t' of
        Node i [ x, xs ] | name i == ":" -> do
            y <- localSuperTerm i [] [xs] $ full x
            return $ Node i [ y, xs ]
        Node i [] | name i == "[]" -> return $ Node i []
        -- report the evaluated term, which is what failed to be a list
        _ -> exception (termRange t') $ "not a list term: " ++ show t'

-- | force full evaluation
-- (result has only constructors, numbers and string literals)
full :: RTerm -> Evaluator RTerm
full x = do
    x' <- top x
    case x' of
        Node f args -> fmap (Node f) $ mapArgs f full args
        Number _ _ -> return x'
        StringLiteral _ _ -> return x'

-- | evaluate until root symbol is constructor.
-- Numbers and string literals are returned as they are.
-- Every call logs the term with its current super terms.
top :: RTerm -> Evaluator RTerm
top t = do
    focus <- MT.lift $ asks superTerms
    MT.lift $ tell [ Term $ TermFocus t focus ]
    case t of
        Number {} -> return t
        StringLiteral {} -> return t
        Node f xs ->
            if Term.isConstructor f
              then return t
              else eval f xs >>= top

-- | Apply an evaluation to the arguments of @i@ from left to right.
-- While an argument is processed, the super-term context records @i@,
-- the already processed arguments (most recent first)
-- and the arguments still to come.
mapArgs ::
    Identifier ->
    (RTerm -> Evaluator RTerm) ->
    [RTerm] -> Evaluator [RTerm]
mapArgs i f =
    let go _ [] = return []
        go done (x:xs) = do
            y <- localSuperTerm i done xs $ f x
            fmap (y:) $ go (y:done) xs
    in  go []

-- | Run an action with one more entry on the super-term stack:
-- identifier @i@ with the arguments @done@ before and @xs@ after
-- the position that is currently evaluated.
localSuperTerm ::
    (Monad m, Monoid w) =>
    Identifier -> [RTerm] -> [RTerm] ->
    ExceptionalT e (MRWS.RWST Context w s m) b ->
    ExceptionalT e (MRWS.RWST Context w s m) b
localSuperTerm i done xs =
    mapExceptionalT
        (MRWS.local
            (\ctx -> ctx{superTerms =
                TermFocus.Node i (TermFocus.List done xs) : superTerms ctx}))

-- | do one reduction step at the root
--
-- Builtin operators evaluate their arguments to root form,
-- log a 'Step' and require exactly two numbers.
-- @<@ and @compare@ yield a nullary node named by the shown result
-- (@True@/@False@ resp. @LT@/@EQ@/@GT@).
-- All other identifiers are looked up among the program's functions.
eval :: Identifier -> [RTerm] -> Evaluator RTerm
eval i xs
    | name i `elem` [ "compare", "<", "-", "+", "*", "div", "mod" ] = do
        ys <- mapArgs i top xs
        MT.lift $ tell $ [ Source $ Step { target = i } ]
        case ys of
            [ Number _ a, Number _ b ] ->
                case name i of
                    -- FIXME: handling of positions is dubious
                    "<" -> return $
                        Node ( Term.Identifier { name = show (a < b) , range = range i } ) []
                    "compare" -> return $
                        Node ( Term.Identifier { name = show (compare a b) , range = range i } ) []
                    "-" -> return $ Number (range i) $ a - b
                    "+" -> return $ Number (range i) $ a + b
                    "*" -> return $ Number (range i) $ a * b
                    -- A zero divisor would otherwise raise a Haskell runtime
                    -- exception outside of the Evaluator's error channel
                    -- as soon as the resulting number is forced.
                    "div"
                        | b == 0 -> exception (range i) "division by zero"
                        | otherwise -> return $ Number (range i) $ div a b
                    "mod"
                        | b == 0 -> exception (range i) "division by zero"
                        | otherwise -> return $ Number (range i) $ mod a b
                    opName ->
                        exception (range i) $ "unknown operation " ++ show opName
            -- two arguments, but at least one of them is not a number
            [ _, _ ] -> exception (range i) $ "arguments must be numbers"
            _ -> exception (range i) $ "wrong number of arguments"
eval g ys = do
    funcs <- MT.lift $ asks ( Program.functions . program )
    case Map.lookup (Module.stripIdentifier g) funcs of
        Nothing ->
            exception (range g) $ unwords [ "unknown function", show $ Node g ys ]
        Just (_name, rules) -> evalDecls g rules ys

-- | Try the rules of function @g@ in order on the arguments.
--
-- Arguments reduced while attempting a rule are passed on
-- in reduced form to the next attempt, so that work is not repeated.
-- On the first match, a 'Step', the 'Rule' and 'Data' entries for those
-- constructors of the rule's patterns that 'Program.constructors' knows
-- are logged, the bindings are substituted into the right-hand side
-- and arguments not consumed by the patterns are appended to it.
-- If no rule matches, evaluation fails.
evalDecls ::
    Identifier -> [ Rule.Rule ModuleRange ] -> [RTerm] -> Evaluator RTerm
evalDecls g =
    foldr
        (\(Rule.Rule f xs rhs) go ys -> do
            MT.lift $ tell [ Source $ AttemptRule f ]
            (m, ys') <- matchExpandList Map.empty g [] xs ys
            case m of
                Nothing -> go ys'
                Just (substitutions, additionalArgs) -> do
                    conss <- MT.lift $ asks ( Program.constructors . program )
                    MT.lift $ tell $ map Source $
                        Step g : Rule f :
                        ( map Data $ Map.elems $
                          Map.intersectionWith const conss $
                          Map.fromList $ map (flip (,) ()) $
                          map Module.stripIdentifier $
                          foldr constructors [] xs )
                    rhs' <- apply substitutions rhs
                    appendArguments rhs' additionalArgs)
        (\ys ->
            exception (range g) $
            unwords
                [ "no matching pattern for function", show g,
                  "and arguments", show ys ])

-- | Collect the constructor identifiers of a pattern in pre-order,
-- prepending them to the accumulator.
-- Sub-terms below a non-constructor node are not visited.
constructors :: RTerm -> [Identifier] -> [Identifier]
constructors (Node f xs) acc =
    if Term.isConstructor f
      then f : foldr constructors acc xs
      else acc
constructors _ acc = acc

-- | Lift 'Term.appendArguments' into the evaluator,
-- reporting a failure at the range of the function term.
appendArguments :: RTerm -> [RTerm] -> Evaluator RTerm
appendArguments f xs =
    case Term.appendArguments f xs of
        Success t -> return t
        Exception e -> exception (termRange f) e

-- | check whether term matches pattern.
-- do some reductions if they are necessary to decide about the match.
-- return the reduced term in the second result component.
matchExpand ::
    RTerm -> RTerm ->
    Evaluator ( Maybe (Map Module.Identifier RTerm) , RTerm )
matchExpand pat t =
    case pat of
        -- a variable matches without any reduction and binds the whole term
        Node f [] | Term.isVariable f ->
            return ( Just $ Map.singleton (Module.stripIdentifier f) t , t )
        -- a constructor pattern needs the term in root-constructor form
        Node f xs | Term.isConstructor f ->
            top t >>= \reduced ->
            case reduced of
                Node g ys
                    | equating name f g -> do
                        ( m, ys' ) <- matchExpandList Map.empty g [] xs ys
                        -- keep the bindings, drop the list of surplus arguments
                        return ( fmap fst m, Node f ys' )
                    | otherwise -> return ( Nothing, reduced )
                _ ->
                    exception (termRange reduced) $
                    "constructor pattern matched against non-constructor term: " ++
                    show reduced
        Node _ _ ->
            exception (termRange pat) $
            "pattern is neither constructor nor number: " ++ show pat
        Number _ a ->
            top t >>= \reduced ->
            case reduced of
                Number _ b -> return ( toMaybe (a==b) Map.empty, reduced )
                _ ->
                    exception (termRange reduced) $
                    "number pattern matched against non-number term: " ++
                    show reduced
        StringLiteral _ a ->
            top t >>= \reduced ->
            case reduced of
                StringLiteral _ b -> return ( toMaybe (a==b) Map.empty, reduced )
                _ ->
                    exception (termRange reduced) $
                    "string pattern matched against non-string term: " ++
                    show reduced

-- | Match patterns against the arguments of @i@ from left to right,
-- extending the bindings @bound@.
--
-- The first result component is 'Nothing' on a mismatch, otherwise
-- the accumulated bindings together with the arguments left over
-- after the last pattern.
-- The second component is always the complete argument list,
-- where each argument that was examined is replaced by its reduced form.
-- Fails if a variable is bound both in @bound@ and in a new match,
-- or if there are fewer arguments than patterns.
matchExpandList ::
    Map Module.Identifier RTerm ->
    Identifier -> [RTerm] -> [RTerm] -> [RTerm] ->
    Evaluator (Maybe (Map Module.Identifier RTerm, [RTerm]), [RTerm])
matchExpandList bound _ _ [] args = return ( Just (bound, args), args )
matchExpandList bound i done (p:ps) (arg:args) = do
    -- the argument is matched in the context of its siblings
    (m, arg') <- localSuperTerm i done args $ matchExpand p arg
    fmap (mapSnd (arg':)) $
        case m of
            Nothing -> return ( Nothing, args )
            Just new -> do
                -- merge bindings, recording every key present on both sides
                let (merged, duplicates) =
                        MW.runWriter $ Trav.sequenceA $
                        Map.unionWithKey
                            (\var t _ -> MW.tell [var] >> t)
                            (fmap return bound) (fmap return new)
                if null duplicates
                  then matchExpandList merged i (arg':done) ps args
                  else
                    exception (termRange arg') $
                    "variables bound more than once in pattern: " ++
                    intercalate ", " (map Module.deconsIdentifier duplicates)
matchExpandList _ _ _ (p:_) _ = exception (termRange p) "too few arguments"

-- | Substitute variable bindings into a term.
--
-- Every visited term is counted by 'checkMaxReductions'.
-- A node whose head is bound gets its substituted arguments
-- appended to the bound term.
apply :: Map Module.Identifier RTerm -> RTerm -> Evaluator RTerm
apply m t = do
    checkMaxReductions (termRange t)
    case t of
        Node f xs -> do
            ys <- mapM ( apply m ) xs
            maybe
                (return $ Node f ys)
                (\boundTerm -> appendArguments boundTerm ys)
                (Map.lookup (Module.stripIdentifier f) m)
        _ -> return t

-- | Count one more step, failing at the given range
-- once the counter has already reached 'maxReductions'.
checkMaxReductions :: ModuleRange -> Evaluator ()
checkMaxReductions rng = do
    limit <- MT.lift $ asks maxReductions
    used <- MT.lift get
    assertT (rng, "number of reductions exceeds limit " ++ show limit) $
        used < limit
    MT.lift $ put $ succ used