{-# LANGUAGE GeneralizedNewtypeDeriving, DeriveDataTypeable, PatternGuards, TupleSections, ViewPatterns #-}

-- | Module for defining and manipulating expressions.
module Proof.Exp.Core(
    Var(..), Con(..), Exp(..), Pat(..),
    fromApps, fromLams, fromLets, lets, lams, apps,
    isVar,
    vars, varsP, free, subst, relabel, relabelAvoid, fresh,
    equivalent,
    fromExp, fromName,
    simplifyExp
    ) where

import Data.Maybe
import Data.List
import Data.Data
import Control.Monad
import Control.Monad.Trans.State
import Data.Char
import Control.Arrow
import Language.Haskell.Exts hiding (Exp,Name,Pat,Var,Let,App,Case,Con,name,parse,Pretty)
import qualified Language.Haskell.Exts as H
import Proof.Exp.HSE
import Control.DeepSeq
import Proof.Util hiding (fresh)
import Data.Generics.Uniplate.Data
import Control.Applicative
import Prelude


---------------------------------------------------------------------
-- TYPE

-- | A variable name. The wrapped string is taken verbatim from the source
-- name, so it may be an alphanumeric identifier or an operator symbol
-- (see 'fromName' and 'toName').
newtype Var = V {fromVar :: String} deriving (Data,Typeable,Eq,Show,Ord,NFData)
-- | A data constructor name, taken verbatim from haskell-src-exts
-- constructor and pattern names (see 'fromExp' and 'fromPat').
newtype Con = C {fromCon :: String} deriving (Data,Typeable,Eq,Show,Ord,NFData)

-- | Untyped core expressions.
--
-- Constructor order is significant: it determines the derived 'Ord'
-- instance. 'Show' and 'Read' are hand-written below via haskell-src-exts.
data Exp
    = Var Var -- ^ variable reference
    | Con Con -- ^ data constructor reference
    | App Exp Exp -- ^ application: function, then argument
    | Let Var Exp Exp -- ^ @let v = e1 in e2@; non-recursive (e1 cannot see v)
    | Lam Var Exp -- ^ single-parameter lambda
    | Case Exp [(Pat,Exp)] -- ^ scrutinee and alternatives; the first matching alternative wins (see 'caseCon')
      deriving (Data,Typeable,Eq,Ord)

-- | Case-alternative patterns. Only flat patterns exist: a constructor
-- applied to variables, or a wildcard. Nested patterns are not representable.
data Pat
    = PCon Con [Var] -- ^ constructor with one binder per field
    | PWild -- ^ matches anything and binds nothing
      deriving (Data,Typeable,Eq,Ord)

-- | Parse Haskell expression syntax with haskell-src-exts, 'deflate' it,
-- then convert to the core language with 'fromExp'. Parse failures surface
-- through 'fromParseResult'.
instance Read Exp where
    readsPrec = simpleReadsPrec parseCore
        where parseCore src = fromExp (deflate (fromParseResult (parseExp src)))

-- | Pretty-print as Haskell source: convert with 'toExp', 'inflate', and drop
-- a single outermost pair of parentheses if one is present.
instance Show Exp where
    show e = case inflate (toExp e) of
        Paren inner -> prettyPrint inner
        other -> prettyPrint other



-- | True exactly when the expression is a bare variable reference.
isVar :: Exp -> Bool
isVar e = case e of
    Var{} -> True
    _ -> False

-- | Deep evaluation of an expression tree: every constructor field is passed
-- to 'rnf' (via the 'rnf2'/'rnf3' helpers, assumed to 'rnf' each argument).
instance NFData Exp where
    rnf e = case e of
        Var v -> rnf v
        Con c -> rnf c
        App fun arg -> rnf2 fun arg
        Let v rhs body -> rnf3 v rhs body
        Lam v body -> rnf2 v body
        Case scrut alts -> rnf2 scrut alts

-- | Deep evaluation of a pattern; a wildcard has nothing to force.
instance NFData Pat where
    rnf p = case p of
        PWild -> ()
        PCon c binders -> rnf2 c binders

-- | Case-of-known-constructor: if the expression is a @case@ whose scrutinee
-- is a constructor application, return the first matching alternative's
-- body together with the bindings from its pattern variables to the
-- constructor's arguments (empty for a wildcard).
--
-- Returns 'Nothing' when the expression is not of that shape. Errors with
-- "Malformed arity" if the matching pattern has the wrong number of
-- binders. When no alternative matches, the "Malformed case" error is
-- passed to 'headNote' as the fallback.
caseCon :: Exp -> Maybe ([(Var,Exp)], Exp)
caseCon whole@(Case scrut alts)
    | (Con con, args) <- fromApps scrut
    = Just $ headNote (error $ "Malformed case: " ++ show whole) [hit | Just hit <- map (tryAlt con args) alts]
    where
        tryAlt _ _ (PWild, rhs) = Just ([], rhs)
        tryAlt con args (PCon con2 binders, rhs)
            | con /= con2 = Nothing
            | length binders /= length args = error "Malformed arity"
            | otherwise = Just (zip binders args, rhs)
caseCon _ = Nothing

-- | Build a left-nested application: @apps f [a, b]@ is @App (App f a) b@.
-- Inverse of 'fromApps'.
apps :: Exp -> [Exp] -> Exp
apps = foldl App

-- | Nest single-parameter lambdas: @lams [a, b] e@ is @Lam a (Lam b e)@.
-- Inverse of 'fromLams'.
lams :: [Var] -> Exp -> Exp
lams vs body = foldr Lam body vs

-- | Nest non-recursive lets with the first binding outermost:
-- @lets [(a,x),(b,y)] e@ is @Let a x (Let b y e)@. Later right-hand sides
-- are therefore in scope of earlier binders. Inverse of 'fromLets'.
lets :: [(Var,Exp)] -> Exp -> Exp
lets bs body = foldr (uncurry Let) body bs


-- | Peel off the outermost chain of 'Let's, returning the bindings
-- (outermost first) and the remaining body. Inverse of 'lets'.
fromLets :: Exp -> ([(Var,Exp)], Exp)
fromLets (Let x y z) = ((x,y):a, b)
    where (a,b) = fromLets z
fromLets x = ([], x)

-- | Peel off the outermost chain of lambdas, returning the parameters
-- (outermost first) and the body. Inverse of 'lams'.
fromLams :: Exp -> ([Var], Exp)
fromLams (Lam x y) = (x:a, b)
    where (a,b) = fromLams y
fromLams x = ([], x)

-- | Split an application spine into its head and its arguments in source
-- order: @fromApps (App (App f a) b) == (f, [a, b])@. Inverse of 'apps'.
fromApps :: Exp -> (Exp, [Exp])
fromApps = go []
    where
        -- Walk down the left spine, consing each argument onto the front.
        -- The list comes out in source order without the quadratic
        -- left-nested (++) of appending one argument at a time.
        go args (App fun arg) = go (arg : args) fun
        go args hd = (hd, args)

---------------------------------------------------------------------
-- BINDING AWARE OPERATIONS

-- | Every 'Var' value occurring anywhere in the expression, as collected by
-- Uniplate's 'universeBi'. This includes binders and pattern variables, not
-- only uses, and keeps repeated occurrences.
vars :: Exp -> [Var]
vars e = universeBi e

-- | The variables bound by a pattern (none for a wildcard), collected with
-- Uniplate's 'universeBi'.
varsP :: Pat -> [Var]
varsP p = universeBi p

-- | Free variables of an expression, without duplicates.
--
-- A lambda or let binder is removed from its body's free variables, and
-- pattern variables from their alternative's body. A let right-hand side is
-- not in scope of its own binder (lets are non-recursive).
free :: Exp -> [Var]
free e = case e of
    Var v -> [v]
    App fun arg -> nub (free fun ++ free arg)
    Lam v body -> delete v (free body)
    Case scrut alts -> nub (free scrut ++ concatMap freeAlt alts)
    Let v rhs body -> nub (free rhs ++ delete v (free body))
    _ -> []
  where
    freeAlt (pat, rhs) = free rhs \\ varsP pat


-- | Substitute expressions for free occurrences of variables.
--
-- A binder (lambda, let, or pattern variable) shadows any substitution for
-- the same name within its scope. This is NOT capture-avoiding: callers
-- are expected to have relabelled binders apart first (see 'relabel').
subst :: [(Var,Exp)] -> Exp -> Exp
subst sub body
    | null sub = body
    | otherwise = case body of
        Var v -> fromMaybe body $ lookup v sub
        App fun arg -> App (subst sub fun) (subst sub arg)
        Lam v inner -> Lam v (under [v] inner)
        Case scrut alts -> Case (subst sub scrut) [(p, under (varsP p) rhs) | (p, rhs) <- alts]
        Let v rhs inner -> Let v (subst sub rhs) (under [v] inner)
        _ -> body
  where
    -- Continue beneath binders, dropping the pairs those binders shadow.
    under bound = subst [pair | pair@(k, _) <- sub, k `notElem` bound]


-- | Rename every binder to a fresh name. The expression's own free
-- variables are excluded from the fresh supply, so none of them is captured.
relabel :: Exp -> Exp
relabel e = relabelAvoid avoid e
    where avoid = free e

-- | Rename every binder (lambda, let, and case-pattern variables) to a name
-- drawn, in traversal order, from @'fresh' xs@, and rewrite bound
-- occurrences to match. Free variables are left untouched. Include them in
-- @xs@ so that no fresh name collides with them.
relabelAvoid :: [Var] -> Exp -> Exp
relabelAvoid xs x = evalState (f [] x) (fresh xs)
    where
        -- @mp@ maps original binder names to their replacements. Inner
        -- binders are consed on the front so 'lookup' finds the innermost.
        f :: [(Var,Var)] -> Exp -> State [Var] Exp
        f mp (Lam v body) = do i <- next; Lam i <$> f ((v,i):mp) body
        f mp (Let v rhs body) = do i <- next; Let i <$> f mp rhs <*> f ((v,i):mp) body
        f mp (Case scrut alts) = Case <$> f mp scrut <*> mapM (g mp) alts
        f mp (App fun arg) = App <$> f mp fun <*> f mp arg
        f mp (Var v) = return $ Var $ fromMaybe v $ lookup v mp
        f _ e = return e

        g mp (PWild, rhs) = (PWild,) <$> f mp rhs
        g mp (PCon c vs, rhs) = do
            is <- replicateM (length vs) next
            (PCon c is,) <$> f (zip vs is ++ mp) rhs

        -- Pop the next fresh name. An explicit case, rather than the
        -- refutable bind @s:ss <- get@, avoids requiring a MonadFail
        -- instance, which State lacks on GHC >= 8.8. With a finite @xs@ the
        -- supply from 'fresh' is infinite, so the empty branch is unreachable.
        next = do
            supply <- get
            case supply of
                s:ss -> do put ss; return s
                [] -> error "relabelAvoid: fresh name supply exhausted"

-- | An infinite supply of candidate names, with each name in @used@ removed.
-- Candidates are ordered by length ("a".."z", then "aa".."zz", ...) and
-- lexicographically within each length.
fresh :: [Var] -> [Var]
fresh used = candidates \\ used
    where candidates = [V name | len <- [1..], name <- replicateM len ['a'..'z']]


-- | Reduce an expression to normal form. 'equivalent' uses this as the
-- reference semantics when checking rewrites.
--
-- Binders are relabelled before reducing, presumably so that 'subst' (which
-- is not capture-avoiding) is applied to terms with distinct binder names.
-- They are relabelled again afterwards, presumably so that alpha-equivalent
-- results compare equal with (==).
--
-- NOTE(review): substitution can duplicate binders, so the initial relabel
-- may not rule out capture in every case. Confirm before relying on this
-- for arbitrary terms.
eval :: Exp -> Exp
eval = relabel . nf . relabel
    where
        -- Weak head normal form.
        -- let: inline the binding into the body.
        whnf (Let v x y) = whnf $ subst [(v,x)] y
        -- beta reduction once the function reduces to a lambda.
        whnf (App (whnf -> Lam v x) y) = whnf $ subst [(v,y)] x
        -- application of a case: push the argument into every alternative.
        whnf (App (whnf -> Case x alts) y) = whnf $ Case x $ map (second $ flip App y) alts
        -- case of a known constructor: bind the fields and continue.
        whnf (Case (whnf -> x) alts) | Just (bs, bod) <- caseCon $ Case x alts = whnf $ subst bs bod
        -- case-of-case: move the outer alternatives into each inner one
        -- (the result is not reduced further here).
        whnf (Case (whnf -> Case x alts1) alts2) = Case x [(a, Case b alts2) | (a,b) <- alts1]
        whnf x = x

        -- Full normal form: reduce the head, then normalise each immediate
        -- sub-expression (Uniplate 'descend').
        nf = descend nf . whnf


-- | @equivalent label before after@ compares @before@ and @after@ via
-- 'equivalentOn' with 'eval' as the normalising function. The label names
-- the check, e.g. @"simplify"@ in 'simplifyExp'.
-- NOTE(review): the result and failure behaviour are defined by
-- 'equivalentOn' in Proof.Util; confirm there.
equivalent :: String -> Exp -> Exp -> Exp
equivalent label before after = equivalentOn eval label before after


---------------------------------------------------------------------
-- SIMPLIFY

-- | Simplify an expression by bottom-up rewriting (Uniplate 'transform').
--
-- The input is relabelled first. The result is passed through 'idempotent'
-- (presumably asserting that simplifying again changes nothing) and
-- 'equivalent' (which checks it against the relabelled input under 'eval').
--
-- Rewrites:
--
-- * @(let bs in \\v -> z) q@ becomes @let v = q in let bs in z@
-- * @case (let v = x in y) of alts@ becomes @let v = x in case y of alts@
-- * @(\\v -> x) y@ becomes @let v = y in x@
-- * @let v = x in y@ inlines @x@ when it is 'cheap' or @v@ is 'linear' in @y@
-- * case-of-case: the outer alternatives move into each inner alternative
-- * case-of-known-constructor ('caseCon') becomes lets binding the fields
simplifyExp :: Exp -> Exp
simplifyExp = \(relabel -> x) -> equivalent "simplify" x $ idempotent "simplify" fs x
    where
        fs = transform f

        f (App (fromLets -> (bs@(_:_), Lam v z)) q) = fs $ Let v q $ lets bs z
        f (Case (Let v x y) alts) = fs $ Let v x $ Case y alts
        {-
        -- True, but a bit different to the others, since it is information propagation
        -- Nothing requires it yet
        f (Case (Var v) alts) | map g alts /= alts = fs $ Case (Var v) $ map g alts
            where g (PCon c vs, x) | v `notElem` vs = (PCon c vs, subst [(v, apps (Con c) $ map Var vs)] x)
                  g x = x
        -}
        f (App (Lam v x) y) = f $ Let v y x
        f (Let v x y) | cheap x || linear v y = fs $ subst [(v,x)] y
        f (Case (Case on alts1) alts2) = fs $ Case on [(p, Case rhs alts2) | (p, rhs) <- alts1]
        f x | Just (bs, bod) <- caseCon x = fs $ lets bs bod
        f x = x

-- | Expressions 'simplifyExp' may inline into a let body regardless of how
-- many times the binder occurs: variables, constructors, and lambdas.
cheap :: Exp -> Bool
cheap e = case e of
    Var _ -> True
    Con _ -> True
    Lam _ _ -> True
    _ -> False


-- | True when @v@ occurs at most once, as estimated by 'count'. In that
-- estimate, case alternatives take the maximum and occurrences under a
-- lambda are doubled.
linear :: Var -> Exp -> Bool
linear v = (<= 1) . count v

-- | Approximate number of times the free variable @v@ is used in an
-- expression.
--
-- Binders named @v@ shadow it. Only one case alternative runs, so the
-- alternatives contribute their maximum (0 when there are none). A use
-- under a lambda may happen many times; doubling is enough to push any
-- non-zero count past the threshold used by 'linear'.
count :: Var -> Exp -> Int
count v e = case e of
    Var x -> if v == x then 1 else 0
    Lam w y -> if v == w then 0 else count v y * 2
    Let w x y -> count v x + (if v == w then 0 else count v y)
    -- Seeding with 0 keeps 'maximum' total for an alternative-free case.
    Case x alts -> count v x + maximum (0 : [altCount p c | (p, c) <- alts])
    App x y -> count v x + count v y
    _ -> 0
  where
    altCount p c = if v `elem` varsP p then 0 else count v c


---------------------------------------------------------------------
-- FROM HSE

-- | Convert a declaration into core bindings.
--
-- A simple variable binding (no guards, no @where@) gives one binding.
-- Type signatures, data declarations and type synonyms give none. Anything
-- else is an error.
fromDecl :: Decl -> [(Var,Exp)]
fromDecl decl = case decl of
    PatBind _ (PVar f) (UnGuardedRhs rhs) (BDecls []) -> [(V (fromName f), fromExp rhs)]
    TypeSig{} -> []
    DataDecl{} -> []
    TypeDecl{} -> []
    _ -> error $ "Unhandled fromDecl: " ++ show decl

-- | Convert a haskell-src-exts expression to the core language.
--
-- Supported forms:
--
-- * lambdas whose parameters are plain variables (identifiers or operator
--   symbols); a multi-parameter lambda becomes nested single-parameter
--   'Lam's
-- * unqualified variables and constructors
-- * application and parentheses
-- * @case@
-- * @let@ with exactly one simple binding, treated as non-recursive
--
-- Anything else is an error.
fromExp :: H.Exp -> Exp
fromExp (Lambda _ ps bod) = lams (map (V . fromPatVar) ps) $ fromExp bod
fromExp (H.App x y) = App (fromExp x) (fromExp y)
fromExp (H.Var (UnQual x)) = Var $ V $ fromName x
fromExp (H.Con (UnQual x)) = Con $ C $ fromName x
fromExp (Paren x) = fromExp x
fromExp (H.Case x xs) = Case (fromExp x) $ map fromAlt xs
fromExp (H.Let (BDecls [d]) x) | [(a,b)] <- fromDecl d =  Let a b $ fromExp x
fromExp x = error $ "Unhandled fromExp: " ++ show x

-- | The raw text of a haskell-src-exts name, whether it is an identifier or
-- an operator symbol.
fromName :: H.Name -> String
fromName n = case n of
    Symbol s -> s
    Ident s -> s

-- | Convert a case alternative. Only unguarded alternatives without a
-- @where@ clause are supported.
fromAlt :: Alt -> (Pat, Exp)
fromAlt alt = case alt of
    Alt _ pat (UnGuardedRhs rhs) (BDecls []) -> (fromPat pat, fromExp rhs)
    _ -> error $ "Unhandled fromAlt: " ++ show alt

-- | Convert a pattern. Supported: an unqualified constructor applied to
-- variable patterns, a wildcard, and either of these in parentheses.
fromPat :: H.Pat -> Pat
fromPat p = case p of
    PParen inner -> fromPat inner
    PApp (UnQual c) args -> PCon (C (fromName c)) [V (fromPatVar a) | a <- args]
    PWildCard -> PWild
    _ -> error $ "Unhandled fromPat: " ++ show p

-- | The name bound by a plain variable pattern. Any other pattern is an
-- error.
fromPatVar :: H.Pat -> String
fromPatVar p = case p of
    PVar n -> fromName n
    _ -> error $ "Unhandled fromPatVar: " ++ show p


---------------------------------------------------------------------
-- TO HSE

-- | Render the binding @v = e@ as a simple haskell-src-exts pattern binding
-- (unguarded, with an empty @where@).
toDecl :: Var -> Exp -> Decl
toDecl v body = PatBind sl lhs rhs (BDecls [])
    where lhs = PVar (toName (fromVar v))
          rhs = UnGuardedRhs (toExp body)

-- | Render a core expression as haskell-src-exts syntax. A let becomes a
-- single-binding @let@, and a lambda a single-parameter lambda.
toExp :: Exp -> H.Exp
toExp e = case e of
    Var v -> H.Var (UnQual (toName (fromVar v)))
    Con c -> H.Con (UnQual (toName (fromCon c)))
    Lam v body -> Lambda sl [PVar (toName (fromVar v))] (toExp body)
    Let v rhs body -> H.Let (BDecls [toDecl v rhs]) (toExp body)
    App fun arg -> H.App (toExp fun) (toExp arg)
    Case scrut alts -> H.Case (toExp scrut) (map toAlt alts)

-- | Render a case alternative as an unguarded haskell-src-exts alternative
-- with an empty @where@.
toAlt :: (Pat, Exp) -> Alt
toAlt (pat, rhs) = Alt sl (toPat pat) body (BDecls [])
    where body = UnGuardedRhs (toExp rhs)

-- | Render a core pattern as haskell-src-exts syntax.
--
-- Pattern variables go through 'toName', exactly as variables do in
-- 'toExp' and 'toDecl'. An operator-named variable therefore becomes a
-- 'Symbol' (matching what 'fromPatVar' accepts) rather than a bare 'Ident'.
toPat :: Pat -> H.Pat
toPat (PCon (C c) vs) = PApp (UnQual $ toName c) (map (PVar . toName . fromVar) vs)
toPat PWild = PWildCard

-- | Choose the haskell-src-exts name kind for a string. Strings starting
-- with an alphanumeric character, @'@, @_@ or @(@ become identifiers;
-- anything else becomes an operator symbol. An empty string is rejected
-- with an explicit error instead of a pattern-match failure.
toName :: String -> H.Name
toName [] = error "toName: empty name"
toName xs@(x:_) | isAlphaNum x || x `elem` "'_(" = Ident xs
                | otherwise = Symbol xs