module Flite.Traversals where

import Flite.Syntax
import Flite.Descend
import Control.Monad
import Data.List
import Flite.Fresh

-- | Names of the functions defined in a program, in declaration order.
-- Declarations that do not match the 'Func' pattern are skipped.
funcs :: Prog -> [String]
funcs prog = [name | Func name _ _ <- prog]

-- | Apply a pure transformation to the right-hand side of every function
-- definition, leaving names and argument lists unchanged.  Declarations that
-- do not match the 'Func' pattern are dropped from the result.
onExp :: (Exp -> Exp) -> Prog -> Prog
onExp f prog =
  [ Func name params body'
  | Func name params body <- prog
  , let body' = f body
  ]

-- | Monadic counterpart of 'onExp': run an effectful transformation over the
-- right-hand side of every function definition, in declaration order, and
-- rebuild each definition with the transformed body.
--
-- Definitions are selected with a generator pattern, exactly as in 'onExp',
-- so a declaration that does not match 'Func' is skipped rather than
-- triggering a pattern-match failure inside a lambda.
-- NOTE(review): 'Decl' is declared elsewhere; if 'Func' is its only
-- constructor this is behaviourally identical to mapping over every element.
onExpM :: Monad m => (Exp -> m Exp) -> Prog -> m Prog
onExpM f p = sequence [rebuild g args rhs | Func g args rhs <- p]
  where
    -- Transform one body and wrap it back into its definition.
    rebuild g args rhs = liftM (Func g args) (f rhs)

-- | Run a list-producing query over the right-hand side of every function
-- definition and concatenate the results in declaration order.
fromExp :: (Exp -> [a]) -> Prog -> [a]
fromExp f prog = [x | Func _ _ body <- prog, x <- f body]

-- | One-level monadic traversal of an expression: the supplied action is
-- applied to each immediate sub-expression (left to right) and the node is
-- rebuilt from the results.  Case patterns, let-bound names, primitive names
-- and lambda binders are kept as they are; any other node is returned as is.
instance Descend Exp where
  descendM f expr = case expr of
    App e es     -> liftM2 App (f e) (mapM f es)
    Case e as    -> liftM2 Case (f e) (mapM onAlt as)
    Let bs e     -> liftM2 Let (mapM onBinding bs) (f e)
    PrimApp p es -> liftM (PrimApp p) (mapM f es)
    Lam vs e     -> liftM (Lam vs) (f e)
    _            -> return expr
    where
      -- Transform the body of a case alternative, keeping its pattern.
      onAlt (pat, body) = liftM ((,) pat) (f body)
      -- Transform the right-hand side of a let binding, keeping its name.
      onBinding (name, rhs) = liftM ((,) name) (f rhs)

-- | @subst x v e@ replaces the occurrences of variable @v@ in @e@ by @x@,
-- stopping wherever @v@ is rebound: a 'Let' that binds @v@ is left
-- untouched, as are case alternatives whose pattern mentions @v@ and
-- lambdas whose parameter list contains @v@.
--
-- Binders inside @e@ are never renamed, so variables free in @x@ can be
-- captured by them.
subst :: Exp -> Id -> Exp -> Exp
subst x v = go
  where
    go expr = case expr of
      Var w | v == w -> x
      Let bs _ | v `elem` map fst bs -> expr
      Case scrut alts -> Case (go scrut) [(p, inAlt p rhs) | (p, rhs) <- alts]
      Lam vs body
        | v `elem` vs -> expr
        | otherwise   -> Lam vs (go body)
      _ -> descend go expr

    -- An alternative whose pattern rebinds v shadows it, so its body is kept.
    inAlt p rhs
      | v `elem` patVars p = rhs
      | otherwise          = go rhs

-- | Apply a list of @(replacement, variable)@ substitutions to an
-- expression.  The list is folded from the right, so the last pair is
-- applied first and the first pair last.
substMany :: Exp -> [(Exp, Id)] -> Exp
substMany e pairs = foldr step e pairs
  where
    step (x, v) acc = subst x v acc

-- | Variables occurring in a pattern, left to right.  Applications are
-- searched in both head and arguments; any other pattern form contributes
-- no variables.
patVars :: Pat -> [Id]
patVars pat = case pat of
  App e es -> concatMap patVars (e : es)
  Var v    -> [v]
  _        -> []

-- | Collect the alternative lists of the case expressions in an expression.
-- For a 'Case' node its own alternatives come first, followed by those found
-- in the scrutinee and then in each alternative's body.  Other nodes are
-- handled by 'extract' from "Flite.Descend".
caseAlts :: Exp -> [[Alt]]
caseAlts (Case scrut alts) = alts : concatMap caseAlts (scrut : map snd alts)
caseAlts e = extract caseAlts e

-- | Free variables of an expression, ignoring the given identifiers, with
-- duplicates removed (first occurrence kept).
freeVarsExcept :: [Id] -> Exp -> [Id]
freeVarsExcept vs = nub . freeVarsExcept' vs

-- | Like 'freeVarsExcept' but without removing duplicates: every free
-- occurrence of a variable yields one entry in the result.
--
-- The first argument is the list of identifiers treated as bound.  It is
-- extended with pattern variables inside case alternatives, with the bound
-- names inside both the body and the right-hand sides of a 'Let', and with
-- the parameters inside a 'Lam'.
freeVarsExcept' :: [Id] -> Exp -> [Id]
freeVarsExcept' = go
  where
    go bound expr = case expr of
      Case scrut alts ->
        go bound scrut
          ++ concat [go (patVars p ++ bound) rhs | (p, rhs) <- alts]
      Let bs body ->
        let bound' = map fst bs ++ bound
        in  go bound' body ++ concatMap (go bound' . snd) bs
      Var w -> [w | w `notElem` bound]
      Lam ws body -> go (ws ++ bound) body
      _ -> extract (go bound) expr

-- | Free variables of an expression, each listed once.
freeVars :: Exp -> [Id]
freeVars = nub . freeVarsExcept' []

-- | Number of free occurrences of the given variable in an expression.
varRefs :: Id -> Exp -> Int
varRefs v e = length [w | w <- freeVarsExcept' [] e, w == v]

-- | Function names referenced through 'Fun' nodes in an expression.  A name
-- appears once per reference; non-'Fun' nodes are handled by 'extract' from
-- "Flite.Descend".
calls :: Exp -> [Id]
calls expr = case expr of
  Fun f -> [f]
  _     -> extract calls expr

-- | All function definitions in the program whose name equals the given
-- identifier, in declaration order.
lookupFuncs :: Id -> Prog -> [Decl]
lookupFuncs f prog = [decl | decl@(Func name _ _) <- prog, f == name]

-- | Rename the variables bound by 'Let' expressions and case patterns to
-- fresh names drawn from the 'Fresh' monad.
--
-- For a 'Let', the body and the right-hand sides are freshened first, then
-- one fresh name is drawn per binding and substituted for the old name in
-- both the right-hand sides and the body.  Case alternatives are handled by
-- 'freshenAlt'.  Every other node is traversed with 'descendM', so lambda
-- parameters keep their names.
freshen :: Exp -> Fresh Exp
freshen (Let bs e) = do
  -- The order of these three steps fixes which fresh names are handed out.
  body' <- freshen e
  rhss' <- mapM freshen (map snd bs)
  news <- mapM (const fresh) olds
  let renaming = zip (map Var news) olds
      rename = flip substMany renaming
  return (Let (zip news (map rename rhss')) (rename body'))
  where
    olds = map fst bs
freshen (Case e as) = liftM2 Case (freshen e) (mapM freshenAlt as)
freshen e = descendM freshen e

-- | Replace every variable in a pattern by a fresh one; the rest of the
-- pattern is traversed with 'descendM'.
freshenPat :: Pat -> Fresh Pat
freshenPat pat = case pat of
  Var _ -> liftM Var fresh
  _     -> descendM freshenPat pat

-- | Freshen one case alternative: the pattern's variables are renamed first,
-- then the body is freshened, and finally each old pattern variable in the
-- body is replaced by the corresponding new one.
freshenAlt :: (Pat, Exp) -> Fresh (Pat, Exp)
freshenAlt (pat, body) = do
  pat' <- freshenPat pat
  body' <- freshen body
  return (pat', substMany body' (renaming pat'))
  where
    -- Pair each new pattern variable with the old one it replaces.
    renaming newPat = zip (map Var (patVars newPat)) (patVars pat)

-- | Freshen a list of binders together with the expression they scope over:
-- one fresh name is drawn per binder, then the expression is freshened, and
-- finally the old binder names in it are replaced by the new ones.
-- NOTE(review): presumably used for a function's parameters and body;
-- confirm against callers.
freshBody :: ([Id], Exp) -> Fresh ([Id], Exp)
freshBody (olds, body) = do
  -- Fresh binder names are drawn before the body is freshened.
  news <- mapM (const fresh) olds
  body' <- freshen body
  return (news, substMany body' (zip (map Var news) olds))