{-# LANGUAGE TemplateHaskell,
             ScopedTypeVariables,
             FlexibleInstances,
             MultiParamTypeClasses,
             FlexibleContexts,
             UndecidableInstances,
             TupleSections,
             GADTs #-}

module A where



import Unbound.LocallyNameless hiding (prec,empty,Data,Refl,Val)

import Unbound.LocallyNameless.Alpha
import Unbound.LocallyNameless.Types

import Control.Monad
import Control.Monad.Except

import Data.Monoid (Monoid(..))

import qualified Data.List as List
import Data.Map (Map)
import qualified Data.Map as Map


import Util
import Text.PrettyPrint as PP


------------------
-- should move to Unbound.LocallyNameless.Ops
-- patUnbind :: (Alpha p, Alpha t) => p -> Bind p t -> t
-- patUnbind p (B _ t) = openT p t
------------------


-- System A

-- Names of type variables: Unbound names that stand for a 'Ty'.
type TyName = Name Ty
-- Names of value variables: Unbound names that stand for a 'Val'.
type ValName = Name Val

-- Per-field flag carried by tuple types ('TyProd').  The typechecker gives
-- every field of a freshly malloc'd tuple 'Un' and switches a field to 'Init'
-- when it is assigned; the pretty-printer shows them as "0" and "1".
data Flag
  = Un
  | Init
  deriving (Eq, Ord, Show)

-- Types of System A.
data Ty
  = TyVar TyName
    -- type variable
  | TyInt
    -- integers
  | All (Bind [TyName] [Ty])
    -- code type: binds type variables over a list of argument types;
    -- printed as "forall[as]. (tys) -> void"
  | TyProd [(Ty, Flag)]  -- new
    -- tuple type; every field carries a 'Flag'
  | Exists (Bind TyName Ty)
    -- existential type: binds one type variable over a body
  deriving Show

-- Small values.
data Val
  = TmInt Int
    -- integer literal
  | TmVar ValName
    -- value variable
  | TApp (Ann Val) Ty
    -- type application, printed as "v [ty]"
  | Pack Ty (Ann Val)
    -- existential package, printed as "pack [ty, v]"
  deriving Show
            
-- Something of type @v@ paired with a type annotation.
data Ann v
  = Ann v Ty
  deriving Show
            
-- Declarations, used as the binding pattern of 'Let'.  The right-hand sides
-- are wrapped in 'Embed'.
data Decl
  = DeclVar ValName (Embed (Ann Val))
    -- x = v
  | DeclPrj Int ValName (Embed (Ann Val))
    -- x = pi<i> v
  | DeclPrim ValName (Embed ((Ann Val), Prim, (Ann Val)))
    -- x = v1 <prim> v2
  | DeclUnpack TyName ValName (Embed (Ann Val))
    -- [a, x] = v
  | DeclMalloc ValName (Embed [Ty]) -- new
    -- x = malloc<tys>
  | DeclAssign ValName (Embed ((Ann Val), Int, (Ann Val))) --new
    -- x = v1 [i] <- v2
  deriving Show
             
-- Terms.
data Tm
  = Let (Bind Decl Tm)
    -- a declaration whose bound names scope over the body
  | App (Ann Val) [(Ann Val)]
    -- call a value with a list of argument values
  | TmIf0 (Ann Val) Tm Tm
    -- if0 (scrutinee, first branch, second branch)
  | Halt Ty (Ann Val)
    -- halt with a value of the stated type
  deriving Show

-- Values stored in the heap.
data HeapVal
  = Tuple [(Ann Val)]
    -- a tuple of annotated values
  | Code (Bind [TyName] (Bind [ValName] Tm))
    -- a code block: type parameters, then value parameters, then a body
  deriving Show

-- The heap: a finite map from value names to annotated heap values.
newtype Heap
  = Heap (Map ValName (Ann HeapVal))
  deriving Show

-- Heaps combine by map union.  'Map.union' is left-biased, so when both heaps
-- bind the same name the binding from the first heap is kept.
instance Monoid A.Heap where
  mempty = A.Heap Map.empty
  A.Heap h1 `mappend` A.Heap h2 = A.Heap (h1 `Map.union` h2)

-- Template Haskell splice: runs the imported 'derive' on every syntax type
-- declared above, before the instance declarations that follow.
$(derive [''HeapVal, ''Flag, ''Ty, ''Val, ''Ann, ''Decl, ''Tm])

------------------------------------------------------
-- 'Alpha' instances for the syntax types.  All bodies are empty, so every
-- method comes from the class's default implementation.
instance Alpha Flag
instance Alpha Ty 
instance Alpha Val 
instance Alpha a => Alpha (Ann a)
instance Alpha Decl
instance Alpha Tm

-- Substituting types for type variables.  Only 'Ty' itself has a variable
-- case ('TyVar'); the remaining instances keep the class defaults.
instance Subst Ty Ty where
  isvar ty = case ty of
    TyVar n -> Just (SubstName n)
    _       -> Nothing
instance Subst Ty Flag
instance Subst Ty Prim
instance Subst Ty Val
instance Subst Ty (Ann Val)
instance Subst Ty Decl
instance Subst Ty Tm

-- Substituting values for value variables.  Only 'Val' itself has a variable
-- case ('TmVar'); the remaining instances keep the class defaults.
instance Subst Val Val where
  isvar v = case v of
    TmVar n -> Just (SubstName n)
    _       -> Nothing
instance Subst Val Flag
instance Subst Val Prim
instance Subst Val Ty
instance Subst Val (Ann Val)
instance Subst Val Decl
instance Subst Val Tm
  
------------------------------------------------------
-- Helper functions
------------------------------------------------------

-- Apply an annotated value to a list of types, one type at a time.
--
-- With no types left the value is returned unchanged.  Otherwise the value's
-- annotation must be an 'All': its first bound type variable is replaced by
-- the next type argument in the argument types, and the result is wrapped in
-- a 'TApp' annotated with the 'All' over the remaining variables.
--
-- Throws "type error: not a polymorphic All" when the 'All' binds no type
-- variables, and "type error: not an All" when the annotation is some other
-- type.
mkTyApp :: (MonadError String m, Fresh m) => (Ann Val) -> [Ty] -> m (Ann Val)
mkTyApp fun tyArgs = case (tyArgs, fun) of
  ([], _) -> return fun
  (t : rest, Ann _ (All binder)) -> do
    (tvs, argTys) <- unbind binder
    case tvs of
      tv : tvs' ->
        mkTyApp (Ann (TApp fun t) (All (bind tvs' (subst tv t argTys)))) rest
      [] -> throwError "type error: not a polymorphic All"
  _ -> throwError "type error: not an All"

-- Wrap a term in a sequence of declarations: the first declaration becomes
-- the outermost 'Let'.  An empty list returns the term unchanged.
lets :: [Decl] -> Tm -> Tm
lets decls body = foldr wrap body decls
  where
    wrap d inner = Let (bind d inner)

-----------------------------------------------------------------
-- Free variables, with types
-----------------------------------------------------------------

-- Fixed term-level names "x", "y" and "z".
x :: Name Tm
x = string2Name "x"

y :: Name Tm
y = string2Name "y"

z :: Name Tm
z = string2Name "z"

-- Fixed type-level names "a", "b" and "c".
a :: Name Ty
a = string2Name "a"

b :: Name Ty
b = string2Name "b"

c :: Name Ty
c = string2Name "c"

-----------------------------------------------------------------
-- Typechecker
-----------------------------------------------------------------

-- Delta: the type variables currently in scope.
type Delta = [ TyName ]
-- Gamma: the value variables currently in scope, each with its type.
-- 'lookupTmVar' uses 'lookup', so the first entry for a name wins.
type Gamma = [ (ValName, Ty) ]

-- Typechecking context: the type variables and the typed value variables in
-- scope.
data Ctx = Ctx
  { getDelta :: Delta
  , getGamma :: Gamma
  }

-- The context with nothing in scope.
emptyCtx :: Ctx
emptyCtx = Ctx [] []

-- Succeed when the type variable is in the context's Delta; otherwise throw
-- an error naming the variable.
checkTyVar :: Ctx -> TyName -> M ()
checkTyVar g v
  | v `List.elem` getDelta g = return ()
  | otherwise = throwError $ "Type variable not found " ++ (show v)

-- Return the type recorded for a value variable in the context's Gamma, or
-- throw an error naming the variable when it is not bound.
lookupTmVar :: Ctx -> ValName -> M Ty
lookupTmVar g v =
  maybe notBound return (lookup v (getGamma g))
  where
    notBound = throwError $ "Term variable notFound " ++ (show v)

-- Put one more type variable at the front of the context's Delta.
extendTy :: TyName -> Ctx -> Ctx
extendTy n (Ctx delta gamma) = Ctx (n : delta) gamma

-- Add a list of type variables to the context.  They end up in Delta in the
-- order given, ahead of whatever was already there.
extendTys :: [TyName] -> Ctx -> Ctx
extendTys [] ctx = ctx
extendTys (n : rest) ctx = extendTy n (extendTys rest ctx)

-- Put one more value-variable binding at the front of the context's Gamma,
-- where it shadows any earlier binding of the same name.
extendTm :: ValName -> Ty -> Ctx -> Ctx
extendTm n ty (Ctx delta gamma) = Ctx delta ((n, ty) : gamma)

-- Add value-variable bindings to the context, pairing the names with the
-- types position by position.
--
-- The two lists must have the same length.  This function is pure and cannot
-- report a type error, so callers are expected to check the lengths first; a
-- mismatch is a caller bug and stops with a descriptive 'error' instead of an
-- anonymous pattern-match failure.
extendTms :: [ValName] -> [Ty] -> Ctx -> Ctx
extendTms [] [] ctx = ctx
extendTms (n:ns) (ty:tys) ctx = extendTm n ty (extendTms ns tys ctx)
extendTms _ _ _ =
  error "A.extendTms: name list and type list differ in length"

{-
extendDecl :: Decl -> Ctx -> Ctx
extendDecl (DeclVar x (Embed (Ann _ ty))) = extendTm x ty
extendDecl (DeclPrj i x (Embed (Ann _ (TyProd tys)))) = extendTm x (tys !! i)                                           
extendDecl (DeclPrim x  _) = extendTm x TyInt
extendDecl (DeclUnpack b x (Embed (Ann _ (Exists bnd)))) = 
  extendTy b . extendTm x (patUnbind b bnd)
-}


-- Check that a type is well formed: every type variable it mentions must be
-- in the context's Delta.  The binders of 'All' and 'Exists' are opened and
-- added to the context before their bodies are checked.
tcty :: Ctx -> Ty -> M ()
tcty ctx ty = case ty of
  TyVar n -> checkTyVar ctx n
  TyInt -> return ()
  TyProd fields -> mapM_ (tcty ctx . fst) fields
  All binder -> do
    (tvs, argTys) <- unbind binder
    let inner = extendTys tvs ctx -- XX
    mapM_ (tcty inner) argTys
  Exists binder -> do
    (tv, body) <- unbind binder
    tcty (extendTy tv ctx) body


-- Check an annotated heap value and return its (annotated) type.
--
-- Code blocks must be annotated with an 'All'.  The code's type binder and
-- the annotation's binder are opened together with 'unbind2'; if that fails
-- the error is "wrong # of type variables".  The number of value parameters
-- must equal the number of argument types.  The argument types are checked
-- for well-formedness and the body is typechecked with the parameters bound
-- to those types.
--
-- Tuples are checked component-wise; the annotation must be alpha-equal to
-- the product of the component types with every field flagged 'Un'.
--
-- Every failure is reported through 'throwError'.
typecheckHeapVal :: Ctx -> Ann HeapVal -> M Ty
typecheckHeapVal g (Ann (Code bnd) (All bnd')) = do
  mb <- unbind2 bnd bnd' -- may fail
  case mb of
    Just (as, bnd2, _, tys) -> do
      (xs, e) <- unbind bnd2
      -- extendTms is partial on lists of different length, so reject a
      -- parameter/argument-type mismatch here as an ordinary type error.
      when (length xs /= length tys) $
        throwError "wrong # of term arguments"
      let g' = extendTys as g
      mapM_ (tcty g') tys
      typecheck (extendTms xs tys g') e
      return (All bnd')
    Nothing -> throwError "wrong # of type variables"
-- A code block whose annotation is not an All type used to fall through every
-- clause and crash with a pattern-match failure.
typecheckHeapVal _ (Ann (Code _) _) =
  throwError "code block must be annotated with an All type"
typecheckHeapVal g (Ann (Tuple es) ty) = do
  tys <- mapM (typecheckAnnVal g) es
  let ty' = TyProd $ map (,Un) tys
  if ty `aeq` ty'
    then return ty
    else throwError "incorrect annotation on tuple"

-- Infer the type of an unannotated value.
--
--   * A variable has the type recorded in the context.
--   * An integer literal has type 'TyInt'.
--   * A type application checks that the type argument is well formed and
--     that the applied value has an 'All' type with at least one bound type
--     variable; the first variable is replaced by the argument and the result
--     is the 'All' over the remaining variables.
--   * A 'Pack' cannot be inferred without its annotation ('typecheckAnnVal'
--     handles annotated packs), so reaching it here is reported as an error.
--
-- Every failure is reported through 'throwError'.
typecheckVal :: Ctx -> Val -> M Ty
typecheckVal g (TmVar x) = lookupTmVar g x
typecheckVal _ (TmInt _) = return TyInt
typecheckVal g (TApp av ty) = do
  tcty g ty
  ty' <- typecheckAnnVal g av
  case ty' of
    All bnd -> do
      (tvs, argTys) <- unbind bnd
      case tvs of
        [] -> throwError "can't instantiate non-polymorphic function"
        (tv:tvs') -> return (All (bind tvs' (subst tv ty argTys)))
    -- Previously a missing case: applying a non-All value to a type crashed.
    _ -> throwError "can't instantiate a value that is not of All type"
-- Previously a missing clause: a bare Pack crashed with a pattern-match failure.
typecheckVal _ (Pack _ _) =
  throwError "cannot infer the type of a pack without its annotation"

-- Check an annotated value and return its annotation.
--
-- A 'Pack' must be annotated with an 'Exists'.  The witness type is checked
-- for well-formedness, the packed value is checked, and its type must be
-- alpha-equal to the existential's body with the bound variable replaced by
-- the witness.
--
-- For any other value the annotation is checked for well-formedness and must
-- be alpha-equal to the type inferred by 'typecheckVal'.
--
-- Every failure is reported through 'throwError'.
typecheckAnnVal :: Ctx -> Ann Val -> M Ty
typecheckAnnVal g (Ann (Pack ty1 av) ty) =
  case ty of
    Exists bnd -> do
      (tv, ty2) <- unbind bnd
      tcty g ty1
      ty' <- typecheckAnnVal g av
      if ty' `aeq` subst tv ty1 ty2
        then return ty
        else throwError "type error"
    -- Previously a missing case: a pack with a non-existential annotation
    -- crashed with a pattern-match failure.
    _ -> throwError "type error: pack must be annotated with an existential type"
typecheckAnnVal g (Ann v ty) = do
  tcty g ty
  ty' <- typecheckVal g v
  if ty `aeq` ty'
    then return ty
    else throwError $ "wrong annotation on: " ++ pp v ++ "\nInferred: " ++ pp ty' ++ "\nAnnotated: " ++ pp ty

-- Check a declaration and return the context extended with what it binds.
--
--   * DeclVar x v        binds x at the type of v.
--   * DeclPrj i x v      v must have a tuple type with a field at index i
--                        (0-based); binds x at that field's type.
--   * DeclPrim x (v1,_,v2)  both operands must be TyInt; binds x at TyInt.
--   * DeclUnpack a x v   v must have an existential type; binds the type
--                        variable a and binds x at the result of
--                        'patUnbind a' on the existential's binder.
--   * DeclMalloc x tys   the types must be well formed; binds x at the tuple
--                        type whose fields are all flagged 'Un'.
--   * DeclAssign x (v1,i,v2)  v1 must have a tuple type with a field at index
--                        i whose type is alpha-equal to the type of v2; binds
--                        x at the same tuple type with field i flagged 'Init'.
--
-- Every failure, including a negative or out-of-range index, is reported
-- through 'throwError'.
typecheckDecl :: Ctx -> Decl -> M Ctx
typecheckDecl g (DeclVar x (Embed av)) = do
  ty <- typecheckAnnVal g av
  return $ extendTm x ty g
typecheckDecl g (DeclPrj i x (Embed av)) = do
  ty <- typecheckAnnVal g av
  case ty of
    -- The i >= 0 test matters: a negative index used to pass the
    -- "i < length tys" guard and then crash inside (!!).
    TyProd tys
      | i >= 0, (fieldTy, _) : _ <- drop i tys ->
          return $ extendTm x fieldTy g
    _ -> throwError "cannot project"
typecheckDecl g (DeclPrim x (Embed (av1, _, av2))) = do
  ty1 <- typecheckAnnVal g av1
  ty2 <- typecheckAnnVal g av2
  case (ty1 , ty2) of
    (TyInt, TyInt) -> return $ extendTm x TyInt g
    _ -> throwError "TypeError"
typecheckDecl g (DeclUnpack tv x (Embed av)) = do
  tya <- typecheckAnnVal g av
  case tya of
    Exists bnd -> do
      let ty = patUnbind tv bnd
      return $ extendTy tv (extendTm x ty g)
    _ -> throwError "TypeError"
typecheckDecl g (DeclMalloc x (Embed tys)) = do
  mapM_ (tcty g) tys
  return $ extendTm x (TyProd (map (,Un) tys)) g
typecheckDecl g (DeclAssign x (Embed (av1, i, av2))) = do
  ty1 <- typecheckAnnVal g av1
  ty2 <- typecheckAnnVal g av2
  case ty1 of
    -- With a negative index, splitAt would return ([], tys) and field 0
    -- would be updated silently, so the index is checked explicitly.
    TyProd tys
      | i >= 0, (before, (ty, _) : after) <- splitAt i tys ->
          if ty `aeq` ty2
            then return $ extendTm x (TyProd (before ++ (ty,Init) : after)) g
            else throwError "TypeError"
    -- Previously a missing case: a non-tuple target or an out-of-range index
    -- crashed with a pattern-match failure.
    _ -> throwError "cannot assign"
         
-- Typecheck a term in a context.
--
--   * Let: the declaration is checked and the body is checked in the context
--     it produces.
--   * App: the applied value must have an 'All' type that binds no type
--     variables (type arguments must already have been supplied with 'TApp');
--     the arguments must match the argument types in number and, position by
--     position, up to alpha-equivalence.
--   * TmIf0: both branches are checked and the scrutinee must be a TyInt.
--   * Halt: the value's type must be alpha-equal to the stated type.
--
-- Every failure is reported through 'throwError'.
typecheck :: Ctx -> Tm -> M ()
typecheck g (Let bnd) = do
  (d,e) <- unbind bnd
  g' <- typecheckDecl g d
  typecheck g' e
typecheck g (App av es) = do
  ty <- typecheckAnnVal g av
  case ty of
    All bnd -> do
      (tvs, argtys) <- unbind bnd
      argtys' <- mapM (typecheckAnnVal g) es
      unless (null tvs) $
        throwError "must use type application"
      when (length argtys /= length argtys') $
        throwError "incorrect args"
      unless (and (zipWith aeq argtys argtys')) $
        throwError "arg mismatch"
    -- Previously a missing case: applying a value whose type is not an All
    -- crashed with a pattern-match failure.
    _ -> throwError "cannot apply a value that is not of All type"
typecheck g (TmIf0 av e1 e2) = do
  ty0 <- typecheckAnnVal g av
  typecheck g e1
  typecheck g e2
  unless (ty0 `aeq` TyInt) $
    throwError "TypeError"
typecheck g (Halt ty av) = do
  ty' <- typecheckAnnVal g av
  unless (ty `aeq` ty') $
    throwError "type error"
         
         
-- Typecheck a whole program: a term together with its heap.  Every heap
-- binding is first entered into the context at its annotated type, then each
-- heap value is checked against that context, and finally the term is
-- checked.
progcheck :: (Tm, Heap) -> M ()
progcheck (tm, Heap m) = do
  mapM_ (typecheckHeapVal heapCtx) (Map.elems m)
  typecheck heapCtx tm
  where
    heapCtx = Map.foldlWithKey addBinding emptyCtx m
    addBinding ctx name (Ann _ ty) = extendTm name ty ctx



-----------------------------------------------------------------
-- Small-step semantics
-----------------------------------------------------------------
  
{-
mkSubst :: Decl -> M (Tm,Heap) -> (Tm,Heap)
mkSubst (DeclVar   x (Embed (Ann v _))) = return $ subst x v
mkSubst (DeclPrj i x (Embed (Ann (TmProd avs) _))) | i < length avs =
       let Ann vi _ = avs !! i in return $ subst x vi
mkSubst (DeclPrim  x (Embed (Ann (TmInt i1) _, p, Ann (TmInt i2) _))) = 
       let v = TmInt (evalPrim p i1 i2) in
       return $ subst x v
mkSubst (DeclUnpack a x (Embed (Ann (Pack ty (Ann u _)) _))) = 
  return $ subst a ty . subst x u  
mkSubst (DeclPrj i x (Embed av)) = 
  throwError $ "invalid prj " ++ pp i ++ ": " ++ pp av
mkSubst (DeclUnpack a x (Embed av)) = 
  throwError $ "invalid unpack:" ++ pp av



step :: (Tm, Heap) -> M (Tm, Heap)

step (Let bnd, heap) = do
  (d, e) <- unbind bnd
  ss     <- mkSubst d
  return $ ss (e, heap)
      
step (App (Ann e1@(Fix bnd) _) avs) = do
    ((f, as), bnd2) <- unbind bnd
    (xtys, e) <- unbind bnd2
    let us = map (\(Ann u _) -> u) avs
    let xs = map fst xtys
    return $ substs ((f,e1):(zip xs us)) e

step (TmIf0 (Ann (TmInt i) _) e1 e2) = if i==0 then return e1 else return e2

step _ = throwError "cannot step"
  
evaluate :: Tm -> M Val
evaluate (Halt _ (Ann v _)) = return v
evaluate e = do
  e' <- step e
  evaluate e'
-}  
-----------------------------------------------------------------
-- Pretty-printer
-----------------------------------------------------------------

-- Types.  An 'All' prints as "(tys) -> void", with a "forall [as]." prefix
-- only when it binds type variables.
instance Display Ty where
  display ty = case ty of
    TyVar n -> display n
    TyInt -> return $ text "Int"
    TyProd fields -> displayTuple fields
    All binder -> lunbind binder $ \(tvs, argTys) -> do
      dTvs  <- displayList tvs
      dArgs <- displayList argTys
      let arrow = parens dArgs <+> text "-> void"
      if null tvs
        then return arrow
        else prefix "forall" (brackets dTvs <> text "." <+> arrow)
    Exists binder -> lunbind binder $ \(tv, body) -> do
      dTv   <- display tv
      dBody <- display body
      prefix "exists" (dTv <> text "." <+> dBody)
    
-- A tuple field prints as its type followed by "^0" ('Un') or "^1" ('Init').
instance Display (Ty, Flag) where
  display (ty, fl) = do
    dty <- display ty
    return $ dty <> text "^" <> text (mark fl)
    where
      mark Un   = "0"
      mark Init = "1"
    
-- A typed binder prints as "name:type".
instance Display (ValName,Embed Ty) where
  display (n, Embed ty) = liftM2 joinWithColon (display n) (display ty)
    where
      joinWithColon dn dt = dn <> colon <> dt
    
-- Values.  A pack prints as "pack [ty,v]" and a type application as "v [ty]".
instance Display Val where
  display val = case val of
    TmInt i -> return $ int i
    TmVar n -> display n
    Pack witness packed -> do
      dWitness <- display witness
      dPacked  <- display packed
      prefix "pack" (brackets (dWitness <> comma <> dPacked))
    TApp fun tyArg -> do
      dFun <- display fun
      dTy  <- display tyArg
      return $ dFun <+> (brackets dTy)

-- Heap values.  A code block prints as "code", its bracketed type parameters
-- and parenthesised value parameters (each omitted when empty), a ".", and
-- then the body below.  A tuple prints through 'displayTuple'.
instance Display HeapVal where
  display (Tuple es) = displayTuple es
  display (Code bnd) = lunbind bnd $ \(tvs, inner) -> lunbind inner $ \(params, body) -> do
    dTvs    <- displayList tvs
    dParams <- displayList params
    dBody   <- withPrec (precedence "code") $ display body
    let tyArgs | null tvs    = empty
               | otherwise   = brackets dTvs
        tmArgs | null params = empty
               | otherwise   = parens dParams
    prefix "code"  (tyArgs <> tmArgs <> text "." $$ dBody)
  

-- An annotated thing prints as the bare thing; the annotation is dropped.
-- A variant that also prints the annotation, kept for reference:
--   display (Ann av ty) = do
--     da <- display av
--     dt <- display ty
--     return $ parens (da <> text ":" <> dt)
instance Display a => Display (Ann a) where
  display (Ann inner _) = display inner

-- Terms.
instance Display Tm where
  display tm = case tm of
    App fun args -> do
      dFun  <- display fun
      dArgs <- displayList args
      -- The argument list is left out entirely when there are no arguments.
      let tmArgs | null args = empty
                 | otherwise = space <> parens dArgs
      return $ dFun <> tmArgs
    Halt _ v -> do
      -- The result type is not printed (a ": type" suffix was disabled).
      dv <- display v
      return $ text "halt" <+> dv
    Let bnd -> lunbind bnd $ \(decl, body) -> do
      dDecl <- display decl
      dBody <- display body
      return $ (text "let" <+> dDecl <+> text "in" $$ dBody)
    TmIf0 scrut thenTm elseTm -> do
      dScrut <- display scrut
      dThen  <- display thenTm
      dElse  <- display elseTm
      prefix "if0" $ parens $ sep [dScrut <> comma , dThen <> comma, dElse]

-- Declarations, one concrete form per constructor:
--   x = v            x = pi<i> v         x = v1 <prim> v2
--   [a,x] = v        x = malloc<tys>     x = v1 [i] <- v2
instance Display Decl where
  display decl = case decl of
    DeclVar x (Embed av) -> do
      dName <- display x
      dRhs  <- display av
      return $ dName <+> text "=" <+> dRhs
    DeclPrj i x (Embed av) -> do
      dName <- display x
      dRhs  <- display av
      return $ dName <+> text "=" <+> text "pi" <> int i <+> dRhs
    DeclPrim x (Embed (lhs, p, rhs)) -> do
      dName <- display x
      dLhs  <- display lhs
      dRhs  <- display rhs
      return $ dName <+> text "=" <+> dLhs <+> text (show p) <+> dRhs
    DeclUnpack tv x (Embed av) -> do
      dTv   <- display tv
      dName <- display x
      dRhs  <- display av
      return $ brackets (dTv <> comma <> dName) <+> text "=" <+> dRhs
    DeclMalloc x (Embed tys) -> do
      dName   <- display x
      dFields <- displayTuple tys
      return $ dName <+> text "= malloc" <> dFields
    DeclAssign x (Embed (target, i, newVal)) -> do
      dName   <- display x
      dTarget <- display target
      dNew    <- display newVal
      return $ dName <+> text "=" <+> dTarget <+> brackets (text (show i))
        <+>  text "<-" <+> dNew

-- A heap prints as "letrec" followed by one "name = value" line per binding,
-- in the key order of 'Map.toList'.
instance Display Heap where
  display (Heap m) = do
    rows <- forM (Map.toList m) $ \(name, hv) -> do
      dName <- display name
      dVal  <- display hv
      return (dName <+> text "=" <+> dVal)
    return $ hang (text "letrec") 2 $ vcat rows

-- A program prints as its heap, then "in" followed by the term underneath.
instance Display (Tm, Heap) where
  display (tm,h) = liftM2 stack (display h) (display tm)
    where
      stack dHeap dTm = dHeap $$ text "in" <+> dTm