module Cryptol.Transform.MonoValues (rewModule) where
import Cryptol.ModuleSystem.Name (SupplyT,liftSupply,Supply,mkDeclared)
import Cryptol.Parser.Position (emptyRange)
import Cryptol.TypeCheck.AST
import Cryptol.TypeCheck.TypeMap
import Cryptol.Utils.Ident (ModName)
import Data.List(sortBy,groupBy)
import Data.Either(partitionEithers)
import Data.Map (Map)
import MonadLib hiding (mapM)
import Prelude ()
import Prelude.Compat
-- | A three-level trie: the outer map is keyed by a 'Name', the middle one by
-- a list of types, and the innermost one by an 'Int'.  In this module the
-- 'Int' is a count of proof applications (see 'splitTApp' and 'rewE').
newtype RewMap' a = RM (Map Name (TypesMap (Map Int a)))
-- | The rewrite map used by this pass: each key is mapped to the 'Name' that
-- replaces the matching instantiation (see 'rewE').
type RewMap = RewMap' Name
-- | 'RewMap'' is a trie over keys of the form @(name, types, int)@, built by
-- nesting the tries of the three key components.
instance TrieMap RewMap' (Name,[Type],Int) where
  emptyTM = RM emptyTM

  nullTM (RM top) = nullTM top

  -- Descend one level per key component; a miss at any level is 'Nothing'.
  lookupTM (name,tys,k) (RM top) =
    lookupTM name top >>= lookupTM tys >>= lookupTM k

  alterTM (name,tys,k) upd (RM top) = RM (alterTM name atName top)
    where
    -- No entry for this name: build the whole path, but only when the
    -- update function actually yields a value.
    atName mb =
      case mb of
        Nothing    -> fmap (\v -> insertTM tys (insertTM k v emptyTM) emptyTM)
                           (upd Nothing)
        Just byTys -> Just (alterTM tys atTys byTys)

    -- Same idea, one level further down.
    atTys mb =
      case mb of
        Nothing    -> fmap (\v -> insertTM k v emptyTM) (upd Nothing)
        Just byInt -> Just (alterTM k upd byInt)

  unionTM combine (RM l) (RM r) =
    RM (unionTM (unionTM (unionTM combine)) l r)

  -- Flatten the nested tries back into full keys.
  toListTM (RM top) =
    do (name,byTys) <- toListTM top
       (tys,byInt)  <- toListTM byTys
       (k,v)        <- toListTM byInt
       return ((name,tys,k),v)

  -- Only the innermost level may drop entries; the outer two levels are
  -- mapped over and keep their keys.
  mapMaybeWithKeyTM f (RM top) =
    RM (mapWithKeyTM
          (\name byTys ->
             mapWithKeyTM
               (\tys byInt ->
                  mapMaybeWithKeyTM (\k v -> f (name,tys,k) v) byInt)
               byTys)
          top)
-- | Rewrite all declaration groups of a module, starting from an empty
-- rewrite map.  The module's name is the reader environment and the given
-- 'Supply' is threaded through; the updated supply is returned alongside
-- the rewritten module.
rewModule :: Supply -> Module -> (Module,Supply)
rewModule s m = runM rewritten (mName m) s
  where
  rewritten   = setDecls <$> mapM (rewDeclGroup emptyTM) (mDecls m)
  setDecls ds = m { mDecls = ds }
-- | The monad of this pass: a reader carrying 'RO' on top of a name supply.
type M = ReaderT RO (SupplyT Id)
-- | The read-only environment: the 'ModName' handed to 'mkDeclared' when
-- fresh names are created (see 'newName').
type RO = ModName
-- | Produce a fresh declared name with identifier @$mono@, using the module
-- name from the reader environment.
newName :: M Name
newName =
  ask >>= \modName ->
    liftSupply (mkDeclared modName "$mono" Nothing emptyRange)
-- | Produce a fresh name for a declaration that may sit either at the top
-- level or inside a @where@ (declaration groups are rewritten in both
-- places).  At present this simply delegates to 'newName'.
newTopOrLocalName :: M Name
newTopOrLocalName = newName
-- | Marks a computation as running over local declarations (it wraps the
-- rewriting of @where@ groups in 'rewE').  It currently runs the
-- computation unchanged.
inLocal :: M a -> M a
inLocal action = action
-- | Rewrite an expression with the given rewrite map.
--
-- A variable applied to type arguments and then to proofs is replaced by
-- the mapped variable when @(variable, type arguments, number of proofs)@
-- is a key of the map.  Every other expression is traversed structurally.
rewE :: RewMap -> Expr -> M Expr
rewE rews = go
  where
  -- Only instantiations whose head is a variable can be rewritten.
  tryRewrite (EVar x,tps,n) = EVar <$> lookupTM (x,tps,n) rews
  tryRewrite _              = Nothing

  -- Use the rewritten form when there is one, otherwise run the fallback.
  rewriteOr key fallback = maybe fallback return (tryRewrite key)

  -- The whole type application is the candidate, with no proofs applied.
  go expr@(ETApp e t)    = rewriteOr (splitTApp expr 0) ((`ETApp` t) <$> go e)
  -- One proof application has already been peeled off here.
  go (EProofApp e)       = rewriteOr (splitTApp e 1) (EProofApp <$> go e)

  go (EList es t)        = (`EList` t) <$> mapM go es
  go (ETuple es)         = ETuple <$> mapM go es
  go (ERec fs)           = ERec <$> mapM (\(f,e) -> (,) f <$> go e) fs
  go (ESel e s)          = (`ESel` s) <$> go e
  go (EIf c t e)         = EIf <$> go c <*> go t <*> go e
  go (EComp len t e mss) = EComp len t <$> go e
                                       <*> mapM (mapM (rewM rews)) mss
  go expr@(EVar _)       = return expr
  go (ETAbs x e)         = ETAbs x <$> go e
  go (EApp f a)          = EApp <$> go f <*> go a
  go (EAbs x t e)        = EAbs x t <$> go e
  go (EProofAbs p e)     = EProofAbs p <$> go e
  go (EWhere e dgs)      = EWhere <$> go e
                                  <*> inLocal (mapM (rewDeclGroup rews) dgs)
-- | Rewrite the expression of a generator, or the declaration of a @let@,
-- inside a comprehension match.
rewM :: RewMap -> Match -> M Match
rewM rews (From x len t e) = From x len t <$> rewE rews e
rewM rews (Let d)          = Let <$> rewD rews d
-- | Rewrite the definition of a declaration, leaving its other fields as
-- they are.
rewD :: RewMap -> Decl -> M Decl
rewD rews d = withDef <$> rewDef rews (dDefinition d)
  where
  withDef def = d { dDefinition = def }
-- | Rewrite the body of an expression definition; primitives are returned
-- unchanged.
rewDef :: RewMap -> DeclDef -> M DeclDef
rewDef rews def =
  case def of
    DExpr e -> DExpr <$> rewE rews e
    DPrim   -> return DPrim
-- | Rewrite a declaration group.
--
-- A non-recursive group just has its declaration rewritten.  In a recursive
-- group, the declarations whose definition starts with at least one type
-- abstraction and whose body (below the type and proof abstractions) is not
-- a function are selected for restructuring; all the others are rewritten
-- in place with 'rewD'.
--
-- The selected declarations are grouped by equal type parameters and equal
-- props, and each group is handled by @rewSame@:
--
--   * every member gets a fresh name and a copy with an empty list of type
--     variables and props in its signature, whose body is rewritten so that
--     instantiations of the original names at exactly the group's type
--     parameters (and number of proofs) refer to the fresh names;
--
--   * a group with one member keeps the original declaration, whose body
--     becomes the type/proof abstractions around a @where@ that defines the
--     copy recursively and returns it;
--
--   * a larger group gets an extra, freshly named declaration that returns
--     the tuple of all copies (defined together in one recursive @where@),
--     and each original declaration becomes a selection from that tuple.
rewDeclGroup :: RewMap -> DeclGroup -> M DeclGroup
rewDeclGroup rews dg =
  case dg of
    NonRecursive d -> NonRecursive <$> rewD rews d
    Recursive ds ->
      do let (leave,rew) = partitionEithers (map consider ds)
             rewGroups   = groupBy sameTParams
                         $ sortBy compareTParams rew
         ds1 <- mapM (rewD rews) leave
         ds2 <- mapM rewSame rewGroups
         return $ Recursive (ds1 ++ concat ds2)

  where
  -- Candidates are tuples @(declaration, type params, props, body)@.
  sameTParams    (_,tps1,ps1,_) (_,tps2,ps2,_) = tps1 == tps2 && ps1 == ps2
  compareTParams (_,tps1,ps1,_) (_,tps2,ps2,_) = compare (ps1,tps1) (ps2,tps2)

  -- 'Right' for declarations to restructure, 'Left' for those to leave alone.
  consider d =
    case dDefinition d of
      DPrim   -> Left d
      DExpr e -> let (tps,props,e') = splitTParams e
                 in if not (null tps) && notFun e'
                      then Right (d, tps, props, e')
                      else Left d

  -- The shared type parameters and props are taken from the first member by
  -- matching in the function head, so the function is total: an empty group
  -- (which 'groupBy' does not produce) simply yields no declarations.
  rewSame [] = return []
  rewSame ds@((_,tps,props,_) : _) =
    do new <- forM ds $ \(d,_,_,e) ->
                do x <- newName
                   return (d, x, e)
       let tys            = map (TVar . tpVar) tps
           proofNum       = length props
           addRew (d,x,_) = insertTM (dName d,tys,proofNum) x
           newRews        = foldr addRew rews new

       -- Pairs of (original declaration, copy without type variables/props).
       newDs <- forM new $ \(d,newN,e) ->
                  do e1 <- rewE newRews e
                     return ( d
                            , d { dName       = newN
                                , dSignature  = (dSignature d)
                                                  { sVars = [], sProps = [] }
                                , dDefinition = DExpr e1
                                }
                            )

       case newDs of
         [(f,f')] ->
           return [ f { dDefinition =
                          let newBody = EVar (dName f')
                              newE    = EWhere newBody
                                          [ Recursive [f'] ]
                          in DExpr $ foldr ETAbs
                                       (foldr EProofAbs newE props) tps
                      }
                  ]

         _ -> do tupName <- newTopOrLocalName
                 let (polyDs,monoDs) = unzip newDs
                     tupAr  = length monoDs
                     -- Wrap an expression in the group's type and proof
                     -- abstractions.
                     addTPs = flip (foldr ETAbs) tps
                            . flip (foldr EProofAbs) props
                     tupD = Decl
                       { dName = tupName
                       , dSignature =
                           Forall tps props $
                             TCon (TC (TCTuple tupAr))
                                  (map (sType . dSignature) monoDs)
                       , dDefinition =
                           DExpr $
                           addTPs $
                           EWhere (ETuple [ EVar (dName d) | d <- monoDs ])
                                  [ Recursive monoDs ]
                       , dPragmas = []
                       , dInfix   = False
                       , dFixity  = Nothing
                       , dDoc     = Nothing
                       }
                     mkProof e _ = EProofApp e
                     -- The n-th original declaration selects the n-th
                     -- component of the tuple, instantiated at the group's
                     -- own type parameters and proofs.
                     mkFunDef n f =
                       f { dDefinition =
                             DExpr $
                             addTPs $ ESel ( flip (foldl mkProof) props
                                           $ flip (foldl ETApp) tys
                                           $ EVar tupName
                                           ) (TupleSel n (Just tupAr))
                         }
                 return (tupD : zipWith mkFunDef [ 0 .. ] polyDs)
-- | Peel the leading type abstractions, and then the leading proof
-- abstractions, off an expression.  Returns the type parameters, the props,
-- and the remaining body.
splitTParams :: Expr -> ([TParam], [Prop], Expr)
splitTParams e = (tps, props, body)
  where
  (tps, afterTys) = splitWhile splitTAbs e
  (props, body)   = splitWhile splitProofAbs afterTys
-- | Decompose an instantiation into its head, its type arguments (in
-- application order), and the number of proof applications.  The 'Int'
-- argument is the number of proof applications already seen; it is
-- incremented strictly for each outer 'EProofApp'.
splitTApp :: Expr -> Int -> (Expr, [Type], Int)
splitTApp expr n =
  case expr of
    EProofApp e -> let n' = n + 1
                   in n' `seq` splitTApp e n'
    _           -> (hd, tys, n)
  where
  (hd, tys) = peel expr []

  -- Collect type arguments, outermost application last.
  peel (ETApp e t) acc = peel e (t : acc)
  peel e           acc = (e, acc)
-- | 'False' exactly when the expression, after skipping any leading proof
-- abstractions, is a lambda abstraction.
notFun :: Expr -> Bool
notFun expr =
  case expr of
    EAbs {}       -> False
    EProofAbs _ e -> notFun e
    _             -> True