module TAL where
import Unbound.LocallyNameless hiding (prec,empty,Data,Refl,Val)
import Unbound.LocallyNameless.Alpha
import Unbound.LocallyNameless.Types
import Control.Monad
import Control.Monad.Except
import Control.Monad.Reader
import Data.Monoid (Monoid(..))
import qualified Data.List as List
import Data.Map (Map)
import qualified Data.Map as Map
import Util
import Text.PrettyPrint as PP
-- | Names for type variables (Unbound names whose sort is 'Ty').
type TyName = Name Ty

-- | Types of the typed assembly language.
data Ty = TyVar TyName
          -- ^ A type variable; 'tcty' requires it to be in 'getDelta'.
        | TyInt
          -- ^ Machine integers; the type of 'TmInt'.
        | All (Bind [TyName] Gamma)
          -- ^ Code type: binds type variables over the register typing the
          -- code expects on entry ('progcheck' gives @Code as g _@ the type
          -- @All (bind as g)@).
        | TyProd [(Ty, Flag)]
          -- ^ Heap tuple type; every field carries its type and an
          -- initialization flag.
        | Exists (Bind TyName Ty)
          -- ^ Existential type, introduced by a pack and eliminated by
          -- 'Unpack'.
  deriving Show

-- | Initialization state of a tuple field.
data Flag = Un
            -- ^ Uninitialized: 'Malloc' produces these and 'Ld' rejects them.
          | Init
            -- ^ Initialized: 'St' marks the stored field this way.
  deriving (Eq, Ord, Show)

-- | Heap typing: the type of every heap label.
type Psi = Map Label Ty

-- | Register typing. Looked up with 'lookup', so the first entry for a
-- register is the one that counts.
type Gamma = [(Register, Ty)]
-- | Machine registers, identified by name (for example @r1@).
newtype Register = Register String deriving (Eq, Ord)

-- | Shows only the register's name, e.g. @r1@.
instance Show Register where
  show (Register s) = s

-- | Register @r1@; 'Halt' takes the program's result type from it.
reg1 :: Register
reg1 = Register "r1"

-- | Temporary register number @i@, named @rt\<i\>@.
rtmp :: Int -> Register
rtmp i = Register ("rt" ++ show i)

-- | Numbered registers: @toEnum n@ is the register named @r\<n\>@.
--
-- 'fromEnum' is only meaningful for registers of exactly that form. For any
-- other name (including temporaries built by 'rtmp', whose names start with
-- @rt@) it fails with a message naming the offending register, rather than
-- with an anonymous pattern-match or @read@ failure.
instance Enum Register where
  toEnum i = Register ("r" ++ show i)
  fromEnum reg@(Register name) =
    case name of
      'r' : digits | [(n, "")] <- reads digits -> n
      _ -> error ("fromEnum: not a numbered register: " ++ show reg)
-- | Heap labels, wrapping an Unbound name.
newtype Label = Label (Name Heap) deriving (Eq, Ord)

-- | A label is shown exactly like its underlying name.
instance Show Label where
  show (Label nm) = show nm
-- | A value applied to one explicit type argument (type instantiation).
data TyApp a = TyApp a Ty deriving Show
-- | Wrap a small value in one 'SApp' per type in the list.
--
-- The LAST type in the list ends up innermost, so @sapps v [t1, t2]@ builds
-- @v[t2][t1]@.
-- NOTE(review): this is the reverse of 'tyApp', which puts the first type
-- innermost; confirm callers rely on this order before changing it.
sapps :: SmallVal -> [Ty] -> SmallVal
sapps v [] = v
sapps v (ty : rest) = SApp (TyApp (sapps v rest) ty)
-- | Existential package @Pack hidden v ann@: the hidden witness type, the
-- packaged value, and the existential annotation it is checked against
-- (see 'tcPack'); 'Unpack' substitutes the witness into the continuation.
data Pack a = Pack Ty a Ty deriving Show

-- | Word-sized values: what registers and tuple fields hold at run time.
data WordVal = LabelVal Label
               -- ^ Pointer to a heap value.
             | TmInt Int
               -- ^ Integer literal.
             | Junk Ty
               -- ^ Placeholder filling a field that 'Malloc' allocated but
               -- nothing has stored into yet.
             | WApp (TyApp WordVal)
               -- ^ Word applied to a type argument.
             | WPack (Pack WordVal)
               -- ^ Existential package of a word.
  deriving Show

-- | Instruction operands: a register or an immediate word, possibly with
-- type applications or packing layered on top.
data SmallVal = RegVal Register
              | WordVal WordVal
              | SApp (TyApp SmallVal)
              | SPack (Pack SmallVal)
  deriving Show

-- | Values stored in the heap.
data HeapVal =
    Tuple [WordVal]
    -- ^ Mutable tuple, read by 'Ld' and written by 'St'.
  | Code [TyName] Gamma InstrSeq
    -- ^ Code block: type parameters, expected register typing, body.
  deriving Show

-- | The machine heap, keyed by label.
type Heap = Map Label HeapVal
-- | Run-time register contents.
type RegisterFile = Map Register WordVal
-- | Straight-line instructions (semantics as implemented by 'step').
data Instruction =
    Add Register Register SmallVal
    -- ^ @rd := rs + v@
  | Bnz Register SmallVal
    -- ^ Jump to @v@ when the register holds a non-zero int.
  | Ld Register Register Int
    -- ^ @rd := field i of the tuple pointed to by rs@
  | Malloc Register [Ty]
    -- ^ Allocate a fresh tuple of 'Junk' fields; @rd@ points to it.
  | Mov Register SmallVal
    -- ^ @rd := v@
  | Mul Register Register SmallVal
    -- ^ @rd := rs * v@
  | St Register Int Register
    -- ^ Store @rs@ into field @i@ of the tuple pointed to by @rd@.
  | Sub Register Register SmallVal
    -- ^ @rd := rs - v@
  | Unpack TyName Register SmallVal
    -- ^ Open the package @v@: bind its witness type to the name and put
    -- the packaged word in @rd@.
  deriving Show

-- | Instruction sequences, each ending in a jump or a halt.
data InstrSeq =
    Seq Instruction InstrSeq
  | Jump SmallVal
  | Halt Ty
    -- ^ Stop; the type describes the result in 'reg1'.
  deriving Show

-- | Abstract machine state: heap, registers, and the code being executed.
type Machine = (Heap, RegisterFile, InstrSeq)
-- Template Haskell: derive the generic representations that Unbound's
-- default 'Alpha' and 'Subst' methods (the instances below) are built on.
$(derive [''Ty, ''Flag, ''Register, ''Label, ''TyApp, ''Pack,
          ''WordVal, ''SmallVal, ''HeapVal, ''Instruction,
          ''InstrSeq])
-- Alpha-equivalence / binding support, all using Unbound's default
-- (generic) methods. 'aeq' and 'unbind' in the typechecker rely on these.
instance Alpha Flag
instance Alpha Ty
instance Alpha Register
instance Alpha Label
instance Alpha a => Alpha (TyApp a)
instance Alpha a => Alpha (Pack a)
instance Alpha WordVal
instance Alpha SmallVal
instance Alpha HeapVal
instance Alpha Instruction
instance Alpha InstrSeq
-- Orphan instance so that register files can appear inside bound terms.
instance Alpha b => Alpha (Map Register b)
-- Substitution of types for type variables.
instance Subst Ty Ty where
  -- 'TyVar' is the only variable form; substitution replaces it.
  isvar (TyVar x) = Just (SubstName x)
  isvar _ = Nothing
-- The remaining instances push type substitution structurally through
-- values and code using the generic defaults.
instance Subst Ty Flag
instance (Subst Ty a) => Subst Ty (TyApp a)
instance (Subst Ty a) => Subst Ty (Pack a)
instance Subst Ty WordVal
instance Subst Ty SmallVal
instance Subst Ty HeapVal
instance Subst Ty Instruction
instance Subst Ty InstrSeq
instance Subst Ty Label
instance Subst Ty Register
instance (Rep a, Subst Ty b) => Subst Ty (Map a b)
-- | Choose a label for a new heap allocation.
--
-- The new label reuses the string of the largest existing label and
-- increments its numeric index. An empty heap (for which the original
-- @maximum@ call crashed) yields the label @l@ with index 0.
freshForHeap :: Heap -> Label
freshForHeap h
  | Map.null h = Label (makeName "l" 0)
  | otherwise = Label (makeName (name2String nm) (name2Integer nm + 1))
  where
    -- Map keys are ordered, so findMax gives the largest label in O(log n).
    Label nm = fst (Map.findMax h)
-- | Read an integer out of a register, failing if the register is missing
-- or holds something other than a 'TmInt'.
getIntReg :: RegisterFile -> Register -> M Int
getIntReg r rs = maybe (throwError "register not found") asInt (Map.lookup rs r)
  where
    asInt (TmInt n) = return n
    asInt _ = throwError "register not an int"
-- | Evaluate a binary integer operation on register @rs@ and operand @v@.
--
-- Any type arguments on the operand are ignored. Fails if either side is
-- not an integer.
arith :: (Int -> Int -> Int) -> RegisterFile ->
         Register -> SmallVal -> M WordVal
arith op r rs v = do
  lhs <- getIntReg r rs
  (operand, _) <- loadReg r v
  case operand of
    TmInt rhs -> return (TmInt (lhs `op` rhs))
    other ->
      throwError ("arith: word val " ++ pp other ++ " is not an int")
-- | Resolve an operand to a word value and the type arguments applied to it.
--
-- Type arguments come back innermost-first, e.g. @v[t1][t2]@ gives
-- @[t1, t2]@. That is the order in which 'tcApp' instantiates a code
-- type's binders. It is also what 'tyApp' needs to rebuild the original
-- application, and what 'jmpReg' needs to zip the arguments with the
-- binders. (The previous version prepended each outer argument and so
-- returned them reversed.)
-- A pack is re-applied to its inner arguments and carries none outward.
loadReg :: RegisterFile -> SmallVal -> M (WordVal, [Ty])
loadReg r (RegVal rs) =
  case Map.lookup rs r of
    Just w -> return (w, [])
    Nothing -> throwError "register val not found"
loadReg _ (WordVal w) = return (w, [])
loadReg r (SApp (TyApp sv ty)) = do
  (w, tys) <- loadReg r sv
  -- sv's arguments are applied before ty, so ty goes last.
  return (w, tys ++ [ty])
loadReg r (SPack (Pack t1 sv t2)) = do
  (w, tys) <- loadReg r sv
  return (WPack (Pack t1 (tyApp w tys) t2), [])
-- | Apply a word value to a list of type arguments. The first type in the
-- list becomes the innermost 'WApp'.
tyApp :: WordVal -> [Ty] -> WordVal
tyApp = List.foldl' (\acc ty -> WApp (TyApp acc ty))
-- | Transfer control to the code named by an operand (used by both @bnz@
-- and @jmp@).
--
-- The target must resolve to a label of a 'Code' block that receives
-- exactly as many type arguments as it has type parameters. The arguments
-- are substituted into the block's body, which becomes the new instruction
-- sequence.
jmpReg :: Heap -> RegisterFile -> SmallVal -> M Machine
jmpReg h r v = do
  (w0, tys0) <- loadReg r v
  -- A register may hold a label that is already type-applied (mov stores
  -- the result of tyApp), so flatten those WApp layers into the argument
  -- list instead of rejecting the word as "not label".
  let (w, tys) = peelApps w0 tys0
  case w of
    LabelVal l ->
      case Map.lookup l h of
        Just (Code alphas _ instrs') -> do
          when (length alphas /= length tys) $
            throwError "Bnz: wrong # type args"
          return (h, r, substs (zip alphas tys) instrs')
        _ -> throwError "Bnz: cannot jump, not code"
    _ -> throwError "Bnz: cannot jump, not label"
  where
    -- Innermost application first, matching the order from loadReg.
    peelApps (WApp (TyApp inner ty)) acc = peelApps inner (ty : acc)
    peelApps word acc = (word, acc)
-- | Execute one instruction of a running machine.
--
-- Fails (via 'throwError') on any stuck state, including:
--
-- * a missing or ill-shaped register;
-- * an out-of-range tuple index (negative indices included);
-- * a jump to something that is not code;
-- * stepping a machine that has already halted.
step :: Machine -> M Machine
step (h, r, Add rd rs v `Seq` instrs) = do
  v' <- arith (+) r rs v
  return (h, Map.insert rd v' r, instrs)
step (h, r, Mul rd rs v `Seq` instrs) = do
  v' <- arith (*) r rs v
  return (h, Map.insert rd v' r, instrs)
step (h, r, Sub rd rs v `Seq` instrs) = do
  -- was `arith ()`: unit is not a function, subtraction was intended
  v' <- arith (-) r rs v
  return (h, Map.insert rd v' r, instrs)
step (h, r, Bnz rs v `Seq` instrs) =
  case Map.lookup rs r of
    Just (TmInt 0) -> return (h, r, instrs)
    Just (TmInt _) -> jmpReg h r v
    _ -> throwError "bnz: register not found or not an int"
step (h, r, Jump v) = jmpReg h r v
step (h, r, Ld rd rs i `Seq` instrs) =
  case Map.lookup rs r of
    Just (LabelVal l) ->
      case Map.lookup l h of
        -- `drop` keeps the lookup total; 0 <= i rejects negative indices.
        Just (Tuple ws) | 0 <= i, (w : _) <- drop i ws ->
          return (h, Map.insert rd w r, instrs)
        _ -> throwError "ld: Cannot load location"
    _ -> throwError "ld: not label"
step (h, r, Malloc rd tys `Seq` instrs) = do
  let l = freshForHeap h
  return ( Map.insert l (Tuple (map Junk tys)) h
         , Map.insert rd (LabelVal l) r
         , instrs )
step (h, r, Mov rd v `Seq` instrs) = do
  (w, tys) <- loadReg r v
  return (h, Map.insert rd (tyApp w tys) r, instrs)
step (h, r, St rd i rs `Seq` instrs) =
  case Map.lookup rs r of
    Just w' ->
      case Map.lookup rd r of
        Just (LabelVal l) ->
          case Map.lookup l h of
            -- The split pattern only matches when 0 <= i < length ws.
            Just (Tuple ws) | 0 <= i, (ws0, _ : ws1) <- splitAt i ws ->
              return (Map.insert l (Tuple (ws0 ++ (w' : ws1))) h, r, instrs)
            _ -> throwError "heap label not found or wrong val"
        _ -> throwError "register not found or wrong val"
    _ -> throwError "register not found"
step (h, r, Unpack alpha rd v `Seq` instrs) = do
  (w0, tys) <- loadReg r v
  case tyApp w0 tys of
    WPack (Pack ty w _) ->
      return (h, Map.insert rd w r, subst alpha ty instrs)
    _ -> throwError "not a pack"
step (_, _, Halt _) = throwError "step: machine has already halted"
-- | Step the machine repeatedly until its code is 'Halt', then return that
-- final machine.
run :: Machine -> M Machine
run m@(_, _, Halt _) = return m
run m = step m >>= run
-- | Type variables in scope.
type Delta = [ TyName ]

-- | Typechecking context: type variables in scope, current register
-- typing, and heap typing.
data Ctx = Ctx { getDelta :: Delta
               , getGamma :: Gamma
               , getPsi   :: Psi
               }

-- | The context with no type variables, no register types and an empty
-- heap typing.
emptyCtx :: Ctx
emptyCtx = Ctx { getDelta = []
               , getGamma = []
               , getPsi   = Map.empty
               }
-- | Fail unless the type variable is in scope in the context.
checkTyVar :: Ctx -> TyName -> M ()
checkTyVar g v =
  unless (v `List.elem` getDelta g) $
    throwError $ "Type variable not found " ++ show v
-- | Bring one type variable into scope (at the front of the list).
extendTy :: TyName -> Ctx -> Ctx
extendTy n ctx = ctx { getDelta = n : getDelta ctx }

-- | Bring several type variables into scope, keeping their list order
-- ahead of the variables already present.
extendTys :: [TyName] -> Ctx -> Ctx
extendTys ns ctx = ctx { getDelta = ns ++ getDelta ctx }
-- | Set a register's type in a register typing.
--
-- Walks past entries whose register compares greater, replaces an equal
-- entry, and otherwise inserts in front. On a list sorted in descending
-- register order this keeps the list sorted.
insertGamma :: Register -> Ty -> Gamma -> Gamma
insertGamma r ty [] = [(r, ty)]
insertGamma r ty gamma@(entry@(r', _) : rest) =
  case compare r r' of
    LT -> entry : insertGamma r ty rest
    EQ -> (r, ty) : rest
    GT -> (r, ty) : gamma
-- | Type of a heap label according to the context's heap typing.
lookupHeapLabel :: Ctx -> Label -> M Ty
lookupHeapLabel ctx v =
  maybe (throwError $ "Label not found " ++ show v) return
        (Map.lookup v (getPsi ctx))

-- | Current type of a register according to the context's register typing.
lookupReg :: Ctx -> Register -> M Ty
lookupReg ctx v =
  maybe (throwError $ "Register not found " ++ show v) return
        (lookup v (getGamma ctx))
-- | Check that a type is well-formed: every type variable it mentions must
-- be in scope, and binders extend the scope for their bodies.
tcty :: Ctx -> Ty -> M ()
tcty ctx ty = case ty of
  TyVar x -> checkTyVar ctx x
  TyInt -> return ()
  TyProd fields -> mapM_ (tcty ctx . fst) fields
  All bnd -> do
    (tvs, gamma) <- unbind bnd
    tcGamma (extendTys tvs ctx) gamma
  Exists bnd -> do
    (tv, body) <- unbind bnd
    tcty (extendTy tv ctx) body
-- | Check that every type in a heap typing is well-formed.
tcPsi :: Ctx -> Psi -> M ()
tcPsi ctx psi = forM_ (Map.elems psi) (tcty ctx)

-- | Check that every register type in a register typing is well-formed.
tcGamma :: Ctx -> Gamma -> M ()
tcGamma ctx g = forM_ g (tcty ctx . snd)
-- | @subtype ctx t1 t2@ succeeds when a value of type @t1@ may be used
-- where @t2@ is expected.
--
-- * Code types: the bodies' register typings must be related by
--   'subGamma', and both types must bind the same number of variables.
-- * Existentials: the bodies must be related.
-- * Tuples (width subtyping): the subtype may have extra trailing fields.
--   A field expected uninitialized ('Un') accepts anything. A field
--   expected initialized ('Init') must also be initialized in the subtype,
--   and its type must be a subtype. The previous version skipped this flag
--   check, letting junk be read as if it were initialized.
subtype :: Ctx -> Ty -> Ty -> M ()
subtype _ (TyVar x) (TyVar y) | x == y = return ()
subtype _ TyInt TyInt = return ()
subtype ctx t1@(All bnd1) t2@(All bnd2) = do
  mb <- unbind2 bnd1 bnd2
  case mb of
    Just (_, g1, _, g2) -> subGamma ctx g1 g2
    -- Differing numbers of bound variables: report a type error instead
    -- of a pattern-match failure.
    Nothing -> throwError $ "not a subtype:" ++ pp t1 ++ "\n" ++ pp t2
subtype ctx t1@(Exists bnd1) t2@(Exists bnd2) = do
  mb <- unbind2 bnd1 bnd2
  case mb of
    Just (_, b1, _, b2) -> subtype ctx b1 b2
    Nothing -> throwError $ "not a subtype:" ++ pp t1 ++ "\n" ++ pp t2
subtype ctx (TyProd tfs1) (TyProd tfs2)
  | length tfs1 >= length tfs2 = zipWithM_ field tfs1 tfs2
  where
    field _ (_, Un) = return ()
    field (t1, Init) (t2, Init) = subtype ctx t1 t2
    field (t1, Un) (t2, Init) =
      throwError $ "not a subtype: uninitialized field " ++ pp t1
                   ++ " used where an initialized " ++ pp t2
                   ++ " is required"
subtype _ t1 t2 = throwError $ "not a subtype:" ++ pp t1 ++ "\n" ++ pp t2
-- | @subGamma ctx g1 g2@ succeeds when register typing @g1@ satisfies
-- @g2@: every register that @g2@ mentions must appear in @g1@ with a
-- subtype of the required type.
--
-- Previously each register's type was compared with itself
-- (@subtype ctx t1 t1@), so only the presence of registers was checked.
subGamma :: Ctx -> Gamma -> Gamma -> M ()
subGamma ctx g1 g2 = mapM_ check g2
  where
    check (r, t2) =
      case lookup r g1 of
        Just t1 -> subtype ctx t1 t2
        Nothing -> throwError $
          "subgamma -- register not found:" ++ show r ++ "\n"
            ++ pp g1 ++ "\n"
            ++ pp g2 ++ "\n"
-- | Check every heap value against the type the heap typing @psi@ gives
-- its label.
--
-- * Tuples must have exactly as many fields as their product type, and
--   every field must agree with its flagged type.
-- * Code blocks are checked with their own type parameters and register
--   typing in scope.
--
-- A label missing from @psi@ is an error.
typeCheckHeap :: Heap -> Psi -> M ()
typeCheckHeap h psi = mapM_ tcHeapDecl (Map.assocs h)
  where
    ctx = emptyCtx { getPsi = psi }

    tcHeapDecl :: (Label, HeapVal) -> M ()
    tcHeapDecl (l, hv) =
      case Map.lookup l psi of
        Just ty -> tcHeapVal hv ty
        Nothing -> throwError $ "heap type not found:" ++ show l

    -- One tuple field against its (type, flag). Junk is only acceptable
    -- in an uninitialized field. Any real word is type-checked, whatever
    -- the flag: an initialized value may always sit in an 'Un' field.
    -- Previously these mixed cases fell through as a pattern-match failure.
    tcTuple (Junk ty', (ty, Un)) = subtype ctx ty' ty
    tcTuple (Junk _, (_, Init)) =
      throwError "heap tuple: junk in a field typed as initialized"
    tcTuple (wv, (ty, _)) = do
      ty' <- tcWordVal ctx wv
      subtype ctx ty' ty

    tcHeapVal (Tuple wvs) (TyProd tys) | length wvs == length tys =
      mapM_ tcTuple (zip wvs tys)
    tcHeapVal (Code as g is) _ = tcInstrSeq (Ctx as g psi) is
    tcHeapVal _ _ = throwError $ "wrong type for heap val"
-- | Infer the type of a word value. 'Junk' is rejected here; it is only
-- meaningful inside heap tuples, which are checked separately.
tcWordVal :: Ctx -> WordVal -> M Ty
tcWordVal ctx wv = case wv of
  LabelVal l -> lookupHeapLabel ctx l
  TmInt _ -> return TyInt
  Junk _ -> throwError $ "BUG: no Junk here"
  WApp tapp -> tcApp tcWordVal ctx tapp
  WPack pack -> tcPack tcWordVal ctx pack
-- | Type a type application @v[ty]@, where @f@ types the inner value.
--
-- The argument must be well-formed. The value must have a code type with
-- at least one binder. The first binder is instantiated with @ty@ and the
-- remaining binders are kept. A value whose type is not a code type now
-- gets a type error instead of a pattern-match failure.
tcApp :: (Ctx -> a -> M Ty) -> Ctx -> TyApp a -> M Ty
tcApp f ctx (TyApp wv ty) = do
  tcty ctx ty
  ty' <- f ctx wv
  case ty' of
    All bnd -> do
      (as, bs) <- unbind bnd
      case as of
        [] -> throwError "can't instantiate non-polymorphic function"
        (a : as') -> return (All (bind as' (subst a ty bs)))
    _ -> throwError "type application to a value that does not have a code type"
-- | Type a package @pack [ty1, v] as ty@, where @f@ types the packaged
-- value.
--
-- The annotation @ty@ must be an existential @exists a. ty2@. The witness
-- @ty1@ must be well-formed. The value's type must be alpha-equivalent to
-- @ty2[ty1/a]@. The result is the annotation. A non-existential annotation
-- now gets a type error instead of a pattern-match failure.
tcPack :: Display a => (Ctx -> a -> M Ty) -> Ctx -> Pack a -> M Ty
tcPack f ctx (Pack ty1 wv ty) =
  case ty of
    Exists bnd -> do
      (a, ty2) <- unbind bnd
      tcty ctx ty1
      ty' <- f ctx wv
      let expected = subst a ty1 ty2
      unless (ty' `aeq` expected) $
        throwError $ "type error in pack " ++ pp wv ++ ":\n"
                     ++ pp ty' ++ "\n"
                     ++ " does not equal\n"
                     ++ pp expected
      return ty
    _ -> throwError $ "pack annotation is not an existential type: " ++ pp ty
-- | Infer the type of an instruction operand.
tcSmallVal :: Ctx -> SmallVal -> M Ty
tcSmallVal ctx sv = case sv of
  RegVal r -> lookupReg ctx r
  WordVal wv -> tcWordVal ctx wv
  SApp app -> tcApp tcSmallVal ctx app
  SPack pack -> tcPack tcSmallVal ctx pack
-- | Typecheck an instruction sequence, threading the context through each
-- instruction.
--
-- * @jmp v@: @v@ must be fully instantiated code whose expected register
--   typing the current one satisfies. The arity rule mirrors the check
--   'jmpReg' makes at run time. Previously @patUnbind []@ crashed
--   internally on a still-polymorphic target, and a non-code target
--   failed with a pattern-match error.
-- * @halt[ty]@: register 'reg1' must hold a subtype of @ty@. This is
--   actual <: declared, the same direction 'subGamma' and heap tuples
--   use; previously the two types were passed in reverse.
tcInstrSeq :: Ctx -> InstrSeq -> M ()
tcInstrSeq ctx (Seq i is) = do
  ctx' <- tcInstr ctx i
  tcInstrSeq ctx' is
tcInstrSeq ctx (Jump sv) = do
  ty <- tcSmallVal ctx sv
  case ty of
    All bnd -> do
      (as, g) <- unbind bnd
      unless (null as) $
        throwError "jmp: target code must be fully instantiated"
      subGamma ctx (getGamma ctx) g
    _ -> throwError "must jmp to code label"
tcInstrSeq ctx (Halt ty) = do
  ty' <- lookupReg ctx reg1
  subtype ctx ty' ty
-- | Shared typing rule for add/mul/sub. The source register and the
-- operand must both be 'TyInt', and afterwards the destination register
-- holds 'TyInt'.
tcArith :: Ctx -> Register -> Register -> SmallVal -> M Ctx
tcArith ctx rd rs sv = do
  srcTy <- lookupReg ctx rs
  operandTy <- tcSmallVal ctx sv
  unless (srcTy `aeq` TyInt) $ throwError "source reg must be int"
  unless (operandTy `aeq` TyInt) $ throwError "immediate must be int"
  return (ctx { getGamma = insertGamma rd TyInt (getGamma ctx) })
-- | Typecheck one instruction and return the context after it.
--
-- Only the register typing changes, plus the type-variable scope for
-- @unpack@.
--
-- Fixes compared with the previous version:
--
-- * @bnz@ uses the same fully-instantiated-target rule as @jmp@ (no
--   internal @patUnbind []@ crash).
-- * Negative @ld@ and @st@ indices are rejected before indexing.
-- * @unpack@ of a non-existential reports a type error.
tcInstr :: Ctx -> Instruction -> M Ctx
tcInstr ctx instr = case instr of
  Add rd rs sv -> tcArith ctx rd rs sv
  Mul rd rs sv -> tcArith ctx rd rs sv
  Sub rd rs sv -> tcArith ctx rd rs sv
  Bnz r sv -> do
    ty1 <- lookupReg ctx r
    ty2 <- tcSmallVal ctx sv
    unless (ty1 `aeq` TyInt) $ throwError "source reg must be int"
    case ty2 of
      All bnd -> do
        (as, g) <- unbind bnd
        -- Mirrors jmpReg's run-time arity check.
        unless (null as) $
          throwError "bnz: target code must be fully instantiated"
        subGamma ctx (getGamma ctx) g
        return ctx
      _ -> throwError "must bnz to code label"
  Ld rd rs i -> do
    ty1 <- lookupReg ctx rs
    case ty1 of
      TyProd tyfs -> do
        when (i < 0 || i >= length tyfs) $ throwError "Ld: index out of range"
        let (ty, f) = tyfs !! i
        unless (f == Init) $ throwError "Ld: load from unitialized field"
        return (setReg rd ty)
      _ -> throwError $ "Ld: not a tuple"
  Malloc rd tys ->
    -- Freshly allocated fields start out uninitialized.
    return (setReg rd (TyProd (map (,Un) tys)))
  Mov rd sv -> do
    ty <- tcSmallVal ctx sv
    return (setReg rd ty)
  St rd i rs -> do
    ty1 <- lookupReg ctx rd
    ty2 <- lookupReg ctx rs
    case ty1 of
      TyProd tyfs -> do
        when (i < 0 || i >= length tyfs) $ throwError "St: index out of range"
        case List.splitAt i tyfs of
          (before, _ : after) ->
            -- The stored field takes rs's type and becomes initialized.
            return (setReg rd (TyProd (before ++ [(ty2, Init)] ++ after)))
          _ -> throwError "St: index out of range"
      _ -> throwError $ "St: rd not a tuple"
  Unpack a rd sv -> do
    when (a `elem` getDelta ctx) $ throwError "Unpack: tyvar not fresh"
    ty1 <- tcSmallVal ctx sv
    case ty1 of
      Exists bnd -> do
        let ty = patUnbind a bnd
        return (ctx { getDelta = a : getDelta ctx
                    , getGamma = insertGamma rd ty (getGamma ctx)
                    })
      _ -> throwError "Unpack: value does not have an existential type"
  where
    -- Context with register r now holding type ty.
    setReg r ty = ctx { getGamma = insertGamma r ty (getGamma ctx) }
-- | Typecheck an initial machine.
--
-- Every heap value must be code; its declared type becomes the heap
-- typing. The register file must be empty. Then the heap typing is checked
-- for well-formedness and the starting instruction sequence is checked in
-- that heap typing.
progcheck :: Machine -> M ()
progcheck (heap, regfile, is) = do
  psi <- Map.traverseWithKey codeType heap
  unless (Map.null regfile) $ throwError "must start with empty registers"
  let ctx = Ctx [] [] psi
  tcPsi ctx psi
  tcInstrSeq ctx is
  where
    codeType _ (Tuple _) = throwError $ "only code to start"
    codeType _ (Code as g _) = return (All (bind as g))
-- | Pretty-print types. Code types with no binders are shown as just their
-- register typing.
instance Display Ty where
  display ty = case ty of
    TyVar n -> display n
    TyInt -> return (text "Int")
    TyProd fields -> displayTuple fields
    All bnd -> lunbind bnd $ \(tvs, gamma) -> do
      dtvs <- displayList tvs
      dgamma <- display gamma
      if null tvs
        then return dgamma
        else prefix "forall" (brackets dtvs <> text "." <+> dgamma)
    Exists bnd -> lunbind bnd $ \(tv, body) -> do
      dtv <- display tv
      dbody <- display body
      prefix "exists" (dtv <> text "." <+> dbody)
-- | A flagged tuple field, shown as @ty^0@ (uninitialized) or @ty^1@
-- (initialized).
instance Display (Ty, Flag) where
  display (ty, fl) = do
    dty <- display ty
    return (dty <> text "^" <> text (marker fl))
    where
      marker Un = "0"
      marker Init = "1"
-- | A register map, shown as @{r1 : v1, r2 : v2, ...}@ in key order.
instance Display a => Display (Map Register a) where
  display m = do
    entries <- forM (Map.toList m) $ \(reg, v) -> do
      dv <- display v
      return (text (show reg) <+> text ":" <+> dv)
    return $ braces (sep (punctuate comma entries))
-- | A register association list (e.g. a 'Gamma'), shown as
-- @{r1 : v1, ...}@ in list order.
instance Display a => Display [(Register, a)] where
  display m = do
    entries <- forM m $ \(reg, v) -> do
      dv <- display v
      return (text (show reg) <+> text ":" <+> dv)
    return $ braces (sep (punctuate comma entries))
-- | Packages are shown as @pack [witness, value]@; the annotation is
-- omitted.
instance Display a => Display (Pack a) where
  display (Pack witness body _) = do
    dWitness <- display witness
    dBody <- display body
    prefix "pack" (brackets (dWitness <> comma <> dBody))
-- | Type application, shown as @value [ty]@.
instance Display a => Display (TyApp a) where
  display (TyApp fun arg) = do
    dFun <- display fun
    dArg <- display arg
    return (dFun <+> brackets dArg)
-- | Word values; junk is shown as @?@.
instance Display WordVal where
  display wv = case wv of
    LabelVal l -> return (text (show l))
    TmInt n -> return (int n)
    Junk _ -> return (text "?")
    WPack p -> display p
    WApp a -> display a
-- | Operands: registers by name, everything else via its own instance.
instance Display SmallVal where
  display sv = case sv of
    RegVal r -> return (text (show r))
    WordVal w -> display w
    SPack p -> display p
    SApp a -> display a
-- | Heap values. Code is shown as @code [tyvars]{regs}.@ followed by its
-- body; the type-parameter brackets are omitted when there are none.
instance Display HeapVal where
  display (Tuple es) = displayTuple es
  display (Code as gamma is) = do
    ds <- displayList as
    dargs <- display gamma
    de <- display is
    let tyArgs | null as = empty
               | otherwise = brackets ds
    prefix "code" (tyArgs <> dargs <> text "." $$ de)
-- | Render a three-operand arithmetic instruction as @op rd,rs, v@.
dispArith str rd rs sv = do
  operand <- display sv
  let dest = text (show rd)
      src = text (show rs)
  return $ text str <+> dest <> comma <> src <> comma <+> operand
-- | Assembly-style rendering of single instructions.
instance Display Instruction where
  display instr = case instr of
    Add rd rs sv -> dispArith "add" rd rs sv
    Mul rd rs sv -> dispArith "mul" rd rs sv
    Sub rd rs sv -> dispArith "sub" rd rs sv
    Bnz r sv -> do
      dv <- display sv
      return $ text "bnz" <+> reg r <> comma <> dv
    Ld rd rs n ->
      return $ text "ld" <+> reg rd <> comma <> reg rs <> brackets (int n)
    Malloc rd tys -> do
      dtys <- displayList tys
      return $ text "malloc" <+> reg rd <> comma <> brackets dtys
    Mov rd sv -> do
      dv <- display sv
      return $ text "mov" <+> reg rd <> comma <> dv
    St rd n rs ->
      return $ text "st" <+> reg rd <> brackets (int n) <> comma <> reg rs
    Unpack a rd sv -> do
      dv <- display sv
      return $ text "unpack"
        <> brackets (text (show a) <> comma <> reg rd)
        <> comma <> dv
    where
      reg r = text (show r)
-- | Instruction sequences, one instruction per line; @halt@ omits its type.
instance Display InstrSeq where
  display code = case code of
    Seq i rest -> do
      dInstr <- display i
      dRest <- display rest
      return (dInstr $+$ dRest)
    Jump sv -> do
      target <- display sv
      return (text "jmp" <+> target)
    Halt _ -> return (text "halt")
-- | Labels are displayed using their 'Show' text.
instance Display Label where
  display = return . text . show
-- | A heap-like map: each label on its own line followed by @:@, with its
-- value indented by 4 below it.
instance Display a => Display (Map Label a) where
  display m = do
    rows <- forM (Map.toList m) $ \(lbl, v) -> do
      dn <- display lbl
      dv <- display v
      return (dn <+> text ":" $$ nest 4 dv)
    return (vcat rows)
-- | A whole machine: the heap, then the registers, then the current code
-- under a @main:@ heading.
instance Display (Heap, RegisterFile, InstrSeq) where
  display (heap, regs, code) = do
    dHeap <- display heap
    dRegs <- display regs
    dCode <- display code
    let mainSection = text "main:" $$ nest 4 dCode
    return $ dHeap $$ dRegs $$ mainSection