module DDC.Core.Transform.Rewrite.Rule
(
BindMode (..)
, isBMSpec
, isBMValue
, RewriteRule (..)
, NamedRewriteRule
, mkRewriteRule
, checkRewriteRule
, Error (..)
, Side (..))
where
import DDC.Core.Transform.Rewrite.Error
import DDC.Core.Transform.Reannotate
import DDC.Core.Transform.TransformUpX
import DDC.Core.Exp
import DDC.Core.Pretty ()
import DDC.Core.Collect
import DDC.Core.Compounds
import DDC.Core.Pretty ()
import DDC.Type.Env (KindEnv, TypeEnv)
import DDC.Base.Pretty
import qualified DDC.Core.Analysis.Usage as U
import qualified DDC.Core.Check as C
import qualified DDC.Core.Collect as C
import qualified DDC.Core.Transform.SpreadX as S
import qualified DDC.Type.Check as T
import qualified DDC.Type.Compounds as T
import qualified DDC.Type.Env as T
import qualified DDC.Type.Equiv as T
import qualified DDC.Type.Predicates as T
import qualified DDC.Type.Subsumes as T
import qualified DDC.Type.Transform.SpreadT as S
import qualified Data.Map as Map
import qualified Data.Maybe as Maybe
import qualified Data.Set as Set
import qualified DDC.Type.Env as Env
-- | A rewrite rule: a left-hand side pattern together with the expression
--   that replaces it.
--
--   'mkRewriteRule' builds a rule with the last three fields left empty;
--   'checkRewriteRule' fills them in.
data RewriteRule a n
        = RewriteRule
        {
          -- | Variables bound by the rule, each tagged with its 'BindMode'.
          ruleBinds       :: [(BindMode, Bind n)]

          -- | Constraints of the rule. 'checkRewriteRule' requires each of
          --   them to be a witness type.
        , ruleConstraints :: [Type n]

          -- | Left-hand side of the rule.
        , ruleLeft        :: Exp a n

          -- | Optional extra expression for the left-hand side.
          --   'checkRewriteRule' applies 'ruleLeft' to it (with 'XApp') to
          --   form the full left-hand side; the pretty printer shows it in
          --   braces after the left-hand side.
        , ruleLeftHole    :: Maybe (Exp a n)

          -- | Right-hand side of the rule (the replacement).
        , ruleRight       :: Exp a n

          -- | Effect of the left-hand side, recorded by 'checkRewriteRule'
          --   when it subsumes, but is not equivalent to, the effect of the
          --   right-hand side. 'Nothing' when the two effects are equivalent.
        , ruleWeakEff     :: Maybe (Effect n)

          -- | Variables that are in the support of the left-hand side but
          --   not of the right-hand side, wrapped as expressions.
          --   Computed by 'checkRewriteRule'.
        , ruleWeakClo     :: [Exp a n]

          -- | Free variables of both sides of the rule, minus the variables
          --   bound by 'ruleBinds'. Computed by 'checkRewriteRule'.
        , ruleFreeVars    :: [Bound n]
        } deriving (Eq, Show)
-- | A rewrite rule paired with a 'String'
--   (presumably the name of the rule -- confirm against the callers).
type NamedRewriteRule a n
        = (String, RewriteRule a n)
-- | Print a rule as
--
--   > [spec binder] (value binder) . constraint => lhs {hole} = rhs
--
--   The binder prefix is omitted when the rule binds nothing, and the
--   braces are omitted when the rule has no hole. The weakening and free
--   variable fields are not printed.
instance (Pretty n, Eq n) => Pretty (RewriteRule a n) where
 ppr rule
  =  binders
  <> constraints
  <> ppr (ruleLeft rule)
  <> holeDoc
  <> text " = "
  <> ppr (ruleRight rule)
  where
        -- All binders followed by a dot, or nothing if there are none.
        binders
         = case ruleBinds rule of
                []      -> text ""
                bs      -> foldr (\b rest -> binder b <> rest) (text ". ") bs

        -- Spec binders go in square brackets, value binders in parens.
        binder (mode, b)
         = case mode of
                BMSpec    -> text "[" <> ppr b <> text "] "
                BMValue _ -> text "(" <> ppr b <> text ") "

        -- Each constraint is followed by an arrow.
        constraints
         = foldr (\c rest -> ppr c <> text " => " <> rest)
                 (text "")
                 (ruleConstraints rule)

        -- The hole, when present, is shown in braces after the lhs.
        holeDoc
         = maybe (text "")
                 (\h -> text " {" <> ppr h <> text "}")
                 (ruleLeftHole rule)
-- | How a variable bound by a rewrite rule is bound.
data BindMode
        -- | Binder that 'extendBinds' adds to the kind environment.
        --   Pretty printed in square brackets.
        = BMSpec

        -- | Binder that 'extendBinds' adds to the type environment.
        --   Pretty printed in parentheses.
        --   For named binders, 'countBinderUsage' sets the 'Int' to the
        --   number of uses of the binder recorded for the right-hand side.
        | BMValue Int
        deriving (Eq, Show)
-- | Check whether a bind mode is 'BMSpec'.
isBMSpec :: BindMode -> Bool
isBMSpec mode
 = case mode of
        BMSpec          -> True
        BMValue _       -> False
-- | Check whether a bind mode is 'BMValue', whatever its usage count.
isBMValue :: BindMode -> Bool
isBMValue mode
 = case mode of
        BMValue _       -> True
        BMSpec          -> False
-- | Construct a rewrite rule from its parts, without checking it.
--
--   The effect weakening, closure weakening and free variable fields are
--   left empty; use 'checkRewriteRule' to compute them.
mkRewriteRule
        :: Ord n
        => [(BindMode, Bind n)] -- ^ Variables bound by the rule.
        -> [Type n]             -- ^ Constraints of the rule.
        -> Exp a n              -- ^ Left-hand side.
        -> Maybe (Exp a n)      -- ^ Optional hole of the left-hand side.
        -> Exp a n              -- ^ Right-hand side.
        -> RewriteRule a n

mkRewriteRule bs cs lhs hole rhs
        = RewriteRule
        { ruleBinds             = bs
        , ruleConstraints       = cs
        , ruleLeft              = lhs
        , ruleLeftHole          = hole
        , ruleRight             = rhs
        , ruleWeakEff           = Nothing
        , ruleWeakClo           = []
        , ruleFreeVars          = [] }
-- | Type check a rewrite rule and fill in the information computed from it.
--
--   The checks run in the order written below and the first one that fails
--   determines the error that is returned. On success the result holds the
--   type-annotated expressions, the spread constraints, the binders with
--   their usage counts, the effect and closure weakenings and the free
--   variables of the rule.
checkRewriteRule
        :: (Ord n, Show n, Pretty n)
        => C.Config n           -- ^ Type checker config.
        -> T.Env n              -- ^ Kind environment.
        -> T.Env n              -- ^ Type environment.
        -> RewriteRule a n      -- ^ Rule to check.
        -> Either (Error a n)
                  (RewriteRule (C.AnTEC a n) n)

checkRewriteRule config kenv tenv rule
 = do
        let bsIn        = ruleBinds    rule
        let lhsIn       = ruleLeft     rule
        let holeIn      = ruleLeftHole rule
        let rhsIn       = ruleRight    rule

        -- Bring the rule's own binders into scope.
        let (kenvRule, tenvRule, bsSpread)
                        = extendBinds bsIn kenv tenv

        -- Every constraint must be a well-kinded witness type.
        let constrs     = map (S.spreadT kenvRule) (ruleConstraints rule)
        mapM_ (checkConstraint config kenvRule) constrs

        -- All expressions are checked in the extended environments.
        let check side x
                        = checkExp config kenvRule tenvRule side x

        -- Check the left-hand side and the hole separately; these annotated
        -- versions are the ones stored in the result.
        (lhsChecked, _, _, _)
                <- check Lhs lhsIn

        holeChecked
                <- case holeIn of
                        Nothing -> return Nothing
                        Just h  -> fmap (\(h', _, _, _) -> Just h')
                                        (check Lhs h)

        -- The full left-hand side is the lhs applied to the hole, if any.
        -- Its type, effect and closure are what get compared with the rhs.
        let lhsFull
                = case holeIn of
                        Nothing -> lhsIn
                        Just h  -> XApp (annotOfExp lhsIn) lhsIn h

        (lhsFullChecked, tyL, effL, cloL)
                <- check Lhs lhsFull

        (rhsChecked, tyR, effR, cloR)
                <- check Rhs rhsIn

        -- Both sides must have equivalent types, and the effect of the
        -- left must be equivalent to, or subsume, the effect of the right.
        let conflict
                = ErrorTypeConflict
                        (tyL, effL, cloL)
                        (tyR, effR, cloR)

        checkEquiv tyL tyR conflict
        weakEff <- makeEffectWeakening T.kEffect effL effR conflict
        weakClo <- makeClosureWeakening config kenvRule tenvRule
                        lhsFullChecked rhsChecked

        -- Structural checks on the binders and on the pattern.
        checkUnmentionedBinders bsSpread lhsFullChecked
        checkAnonymousBinders   bsSpread
        checkValidPattern       lhsFull

        -- Record how often each value binder is used in the rhs.
        bsCounted <- countBinderUsage bsSpread rhsIn

        -- Free variables of the rule: those of both sides, minus the
        -- variables bound by the rule itself.
        let bound
                = Set.fromList
                $ Maybe.mapMaybe (T.takeSubstBoundOfBind . snd) bsIn

        let free
                = Set.toList
                $ Set.difference
                        (Set.union (C.freeX T.empty lhsFullChecked)
                                   (C.freeX T.empty rhsIn))
                        bound

        return  $ RewriteRule
                { ruleBinds             = bsCounted
                , ruleConstraints       = constrs
                , ruleLeft              = lhsChecked
                , ruleLeftHole          = holeChecked
                , ruleRight             = rhsChecked
                , ruleWeakEff           = weakEff
                , ruleWeakClo           = weakClo
                , ruleFreeVars          = free }
-- | Add the binders of a rule to the kind and type environments.
--
--   The binders are processed left to right. Each one is spread using the
--   environments as extended by the binders before it, then added to the
--   kind environment if it is a 'BMSpec' binder, or to the type environment
--   if it is a 'BMValue' binder.
--
--   Returns the extended kind environment, the extended type environment
--   and the spread binders in their original order.
extendBinds
        :: Ord n
        => [(BindMode, Bind n)]
        -> KindEnv n -> TypeEnv n
        -> (T.KindEnv n, T.TypeEnv n, [(BindMode, Bind n)])

extendBinds binds kenv tenv
 = go binds kenv tenv []
 where
        -- The accumulator holds the spread binders in reverse order, so each
        -- step is a constant-time cons rather than a linear-time append.
        -- It is reversed once at the end to restore the original order.
        go [] k t acc
         = (k, t, reverse acc)

        go ((bm, b) : bs) k t acc
         = let  b'       = S.spreadX k t b
                (k', t') = case bm of
                                BMSpec    -> (T.extend b' k, t)
                                BMValue _ -> (k, T.extend b' t)
           in   go bs k' t' ((bm, b') : acc)
-- | Spread and type check one side of a rule.
--
--   On success returns the annotated expression with its type, effect and
--   closure. A checker error is wrapped in 'ErrorTypeCheck', together with
--   the side of the rule and the spread expression that was checked.
checkExp
        :: (Ord n, Show n, Pretty n)
        => C.Config n   -- ^ Type checker config.
        -> KindEnv n    -- ^ Kind environment.
        -> TypeEnv n    -- ^ Type environment.
        -> Side         -- ^ Side of the rule, used for error reporting.
        -> Exp a n      -- ^ Expression to check.
        -> Either (Error a n)
                  (Exp (C.AnTEC a n) n, Type n, Effect n, Closure n)

checkExp defs kenv tenv side xx
 = either (Left . ErrorTypeCheck side spread) Right result
 where
        spread  = S.spreadX kenv tenv xx

        -- Only the first component of the checker's output is used.
        result  = fst (C.checkExp defs kenv tenv spread C.Recon)
-- | Check a constraint of a rule, returning its kind.
--
--   The constraint must pass 'T.checkSpec' and must be a witness type;
--   otherwise the result is 'ErrorBadConstraint'.
checkConstraint
        :: (Ord n, Show n, Pretty n)
        => C.Config n
        -> KindEnv n
        -> Type n
        -> Either (Error a n) (Kind n)

checkConstraint config kenv tt
 = case T.checkSpec config kenv tt of
        Right (_, k)
         | T.isWitnessType tt   -> Right k

        -- Either ill-kinded, or well-kinded but not a witness type.
        _                       -> Left (ErrorBadConstraint tt)
-- | Succeed if the two types are equivalent,
--   otherwise fail with the given error.
checkEquiv
        :: Ord n
        => Type n       -- ^ Type of the left-hand side.
        -> Type n       -- ^ Type of the right-hand side.
        -> Error a n    -- ^ Error to return when they are not equivalent.
        -> Either (Error a n) ()

checkEquiv tLeft tRight err
 = if T.equivT tLeft tRight
        then Right ()
        else Left err
-- | Decide what effect weakening a rule needs.
--
--   * Equivalent effects need no weakening: the result is 'Nothing'.
--
--   * If the left effect subsumes the right one, the result is the left
--     effect.
--
--   * Otherwise the given error is returned.
makeEffectWeakening
        :: (Ord n, Show n)
        => Kind n       -- ^ Kind passed to the subsumption check.
        -> Effect n     -- ^ Effect of the left-hand side.
        -> Effect n     -- ^ Effect of the right-hand side.
        -> Error a n    -- ^ Error to return when neither case applies.
        -> Either (Error a n) (Maybe (Type n))

makeEffectWeakening k effLeft effRight onError
 = decide (T.equivT effLeft effRight)
          (T.subsumesT k effLeft effRight)
 where
        -- The subsumption test is only forced when the effects
        -- are not equivalent.
        decide True _           = Right Nothing
        decide _    True        = Right (Just effLeft)
        decide _    _           = Left onError
-- | Build the closure weakening for a rule.
--
--   Effect type arguments are first replaced on both sides by
--   'removeEffects'. The result lists, as expressions annotated like the
--   left-hand side, the data variables, then the witness variables, then
--   the spec variables that are in the support of the left-hand side but
--   not of the right-hand side. This function always succeeds.
makeClosureWeakening
        :: (Ord n, Pretty n, Show n)
        => C.Config n                   -- ^ Type checker config.
        -> T.Env n                      -- ^ Kind environment.
        -> T.Env n                      -- ^ Type environment.
        -> Exp (C.AnTEC a n) n          -- ^ Full left-hand side.
        -> Exp (C.AnTEC a n) n          -- ^ Right-hand side.
        -> Either (Error a n)
                  [Exp (C.AnTEC a n) n]

makeClosureWeakening config kenv tenv lhs rhs
 = Right (dataVars ++ witnessVars ++ specVars)
 where
        annot           = annotOfExp lhs

        -- Support of an expression once its effect arguments are removed.
        supportOf x
         = support Env.empty Env.empty (removeEffects config kenv tenv x)

        suppLeft        = supportOf lhs
        suppRight       = supportOf rhs

        -- Members of one component of the support that appear on the
        -- left but not on the right.
        onlyLeft field
         = Set.toList (field suppLeft `Set.difference` field suppRight)

        dataVars        = map (XVar annot)
                        $ onlyLeft supportDaVar

        witnessVars     = map (XWitness annot . WVar annot)
                        $ onlyLeft supportWiVar

        specVars        = map (XType annot . TVar)
                        $ onlyLeft supportSpVar
-- | Replace every type argument of effect kind in an expression with the
--   bottom effect. Anything else is left as it is, including type
--   arguments that fail the kind check.
removeEffects
        :: (Ord n, Pretty n, Show n)
        => C.Config n   -- ^ Type checker config.
        -> T.Env n      -- ^ Kind environment.
        -> T.Env n      -- ^ Type environment.
        -> Exp a n      -- ^ Expression to transform.
        -> Exp a n

removeEffects config
 = transformUpX eraseEffect
 where
        eraseEffect kenv _ x
         = case x of
                XType a t
                 | hasEffectKind kenv t -> XType a (T.tBot T.kEffect)
                _                       -> x

        -- A type counts as an effect only if it kind checks
        -- and its kind is the effect kind.
        hasEffectKind kenv t
         = case T.checkSpec config kenv t of
                Right (_, k)    -> T.isEffectKind k
                Left _          -> False
-- | Check that every variable bound by the rule is free in the given
--   expression, either as a value variable or as a type variable.
--   Fails with 'ErrorVarUnmentioned' otherwise.
checkUnmentionedBinders
        :: (Ord n, Show n)
        => [(BindMode, Bind n)]         -- ^ Binders of the rule.
        -> Exp (C.AnTEC a n) n          -- ^ Full left-hand side.
        -> Either (Error a n) ()

checkUnmentionedBinders bs expr
 | all (`Set.member` mentioned) bound   = Right ()
 | otherwise                            = Left ErrorVarUnmentioned
 where
        -- Free value variables and free type variables of the expression.
        mentioned
         = Set.union (C.freeX T.empty expr) (C.freeT T.empty expr)

        -- Binders with no substitutable bound are skipped.
        bound
         = Maybe.mapMaybe (T.takeSubstBoundOfBind . snd) bs
-- | Check that none of the binders of the rule is anonymous.
--   Fails with 'ErrorAnonymousBinder' on the first anonymous binder found.
checkAnonymousBinders
        :: [(BindMode, Bind n)]
        -> Either (Error a n) ()

checkAnonymousBinders bs
 = case [b | (_, b) <- bs, T.isBAnon b] of
        b : _   -> Left (ErrorAnonymousBinder b)
        []      -> Right ()
-- | Check that an expression can be used as the left-hand side of a rule.
--
--   Variables, constructors, witnesses, applications and casts are allowed,
--   provided their sub-expressions are. Abstractions, let and case
--   expressions, and type arguments that contain a forall, are rejected
--   with 'ErrorNotFirstOrder'.
checkValidPattern :: Exp a n -> Either (Error a n) ()
checkValidPattern expr
 = checkX expr
 where
        checkX xx
         = case xx of
                XVar{}          -> ok
                XCon{}          -> ok
                XLAM{}          -> notFirstOrder xx
                XLam{}          -> notFirstOrder xx
                XApp _ x1 x2    -> checkX x1 >> checkX x2
                XLet{}          -> notFirstOrder xx
                XCase{}         -> notFirstOrder xx
                XCast _ _ x     -> checkX x
                XType a t       -> checkT a t
                XWitness{}      -> ok

        -- The annotation is only needed to rebuild the offending
        -- type argument for the error.
        checkT a tt
         = case tt of
                TVar{}          -> ok
                TCon{}          -> ok
                TForall{}       -> notFirstOrder (XType a tt)
                TApp t1 t2      -> checkT a t1 >> checkT a t2
                TSum{}          -> ok

        ok              = Right ()
        notFirstOrder x = Left (ErrorNotFirstOrder x)
-- | Set the count of each named value binder to the number of uses
--   recorded for its name in the usage map of the given expression.
--
--   Spec binders, and value binders that are not named, are returned
--   unchanged. A name missing from the usage map gets a count of zero.
--   This function always succeeds.
countBinderUsage
        :: Ord n
        => [(BindMode, Bind n)]         -- ^ Binders of the rule.
        -> Exp a n                      -- ^ Expression to count uses in.
        -> Either (Error a n) [(BindMode, Bind n)]

countBinderUsage bs x
 = Right (map recount bs)
 where
        -- The usage map is read from the annotation at the root of the
        -- usage-annotated expression.
        U.UsedMap uses
                = fst (annotOfExp (U.usageX x))

        recount binder
         = case binder of
                (BMValue _, BName n t)
                  -> ( BMValue (maybe 0 length (Map.lookup n uses))
                     , BName n t)

                _ -> binder
-- | Reannotate every expression held in the rule: the left-hand side, the
--   hole, the right-hand side and the closure weakening expressions.
--   The other fields are copied across unchanged.
instance Reannotate RewriteRule where
 reannotate f rule
  = RewriteRule
        { ruleBinds             = ruleBinds rule
        , ruleConstraints       = ruleConstraints rule
        , ruleLeft              = reannotate f (ruleLeft rule)
        , ruleLeftHole          = fmap (reannotate f) (ruleLeftHole rule)
        , ruleRight             = reannotate f (ruleRight rule)
        , ruleWeakEff           = ruleWeakEff rule
        , ruleWeakClo           = map (reannotate f) (ruleWeakClo rule)
        , ruleFreeVars          = ruleFreeVars rule }