module DDC.Core.Transform.Rewrite.Rule
(
BindMode (..)
, isBMSpec
, isBMValue
, RewriteRule (..)
, NamedRewriteRule
, mkRewriteRule
, checkRewriteRule
, Error (..)
, Side (..))
where
import DDC.Core.Transform.Rewrite.Error
import DDC.Core.Transform.Reannotate
import DDC.Core.Transform.TransformUpX
import DDC.Core.Exp.Annot
import DDC.Core.Pretty ()
import DDC.Core.Collect
import DDC.Core.Pretty ()
import DDC.Type.Env (KindEnv, TypeEnv)
import DDC.Base.Pretty
import qualified DDC.Core.Analysis.Usage as U
import qualified DDC.Core.Check as C
import qualified DDC.Core.Collect as C
import qualified DDC.Core.Transform.SpreadX as S
import qualified DDC.Type.Check as T
import qualified DDC.Type.Compounds as T
import qualified DDC.Type.Env as T
import qualified DDC.Type.Equiv as T
import qualified DDC.Type.Predicates as T
import qualified DDC.Type.Subsumes as T
import qualified DDC.Type.Transform.SpreadT as S
import qualified Data.Map as Map
import qualified Data.Maybe as Maybe
import qualified Data.Set as Set
import qualified DDC.Type.Env as Env
-- | A rewrite rule.
--
--   'mkRewriteRule' builds a rule whose weakening and free-variable fields
--   are empty. 'checkRewriteRule' returns a rule with type-annotated
--   expressions and those fields filled in.
data RewriteRule a n
        = RewriteRule
        { ruleBinds       :: [(BindMode, Bind n)]
          -- ^ Variables bound by the rule, each tagged with its 'BindMode'.
        , ruleConstraints :: [Type n]
          -- ^ Constraints on the rule. 'checkRewriteRule' requires each one
          --   to kind-check and be a witness type.
        , ruleLeft        :: Exp a n
          -- ^ Left-hand side pattern.
        , ruleLeftHole    :: Maybe (Exp a n)
          -- ^ Optional extra argument. When present, the full left-hand
          --   side is 'ruleLeft' applied to it.
        , ruleRight       :: Exp a n
          -- ^ Right-hand side (replacement).
        , ruleWeakEff     :: Maybe (Effect n)
          -- ^ Effect weakening. After checking this is the left-hand effect
          --   when it subsumes, but is not equivalent to, the right-hand one;
          --   otherwise 'Nothing'.
        , ruleWeakClo     :: [Exp a n]
          -- ^ Closure weakening. After checking these are the value,
          --   witness and type variables in the support of the left-hand
          --   side that are missing from the right-hand side.
        , ruleFreeVars    :: [Bound n]
          -- ^ Free variables of both sides (as found by 'freeX') that are
          --   not bound by the rule.
        } deriving (Eq, Show)
-- | A rewrite rule paired with a 'String' (presumably the rule's name;
--   confirm against callers).
type NamedRewriteRule a n
        = (String, RewriteRule a n)
-- | Pretty-print a rule.
--
--   Binder printing:
--   * spec binders are printed as "[b] " and value binders as "(b) ";
--   * a non-empty binder list is followed by ". ".
--
--   Then come the constraints, each followed by " => ", then the left-hand
--   side, then the hole (if any) as " {h}", then " = " and the right-hand
--   side. The weakening and free-variable fields are not printed.
instance (Pretty n, Eq n) => Pretty (RewriteRule a n) where
 ppr (RewriteRule binders constraints lhs mHole rhs _ _ _)
        =  pprBinderList
        <> pprConstraintList constraints
        <> ppr lhs
        <> pprOptHole
        <> text " = "
        <> ppr rhs
  where
        pprBinderList
         = case binders of
                [] -> text ""
                _  -> foldl1 (<>) (map pprBinder binders) <> text ". "

        pprBinder (mode, b)
         = case mode of
                BMSpec    -> text "[" <> ppr b <> text "] "
                BMValue _ -> text "(" <> ppr b <> text ") "

        pprConstraintList []         = text ""
        pprConstraintList (c : more) = ppr c <> text " => " <> pprConstraintList more

        pprOptHole
         = maybe (text "") (\h -> text " {" <> ppr h <> text "}") mHole
-- | How a rewrite-rule binder is bound.
data BindMode
        = BMSpec
          -- ^ Type-level binder. 'extendBinds' adds it to the kind
          --   environment; it is pretty-printed in square brackets.
        | BMValue Int
          -- ^ Value-level binder. 'extendBinds' adds it to the type
          --   environment; it is pretty-printed in parentheses.
          --   'countBinderUsage' overwrites the 'Int' (for named binders)
          --   with the number of uses the usage analysis records in the
          --   right-hand side.
        deriving (Eq, Show)
-- | True exactly for the type-level binder mode 'BMSpec'.
isBMSpec :: BindMode -> Bool
isBMSpec mode
 = case mode of
        BMSpec    -> True
        BMValue{} -> False
-- | True exactly for the value-level binder mode 'BMValue', whatever its
--   usage count.
isBMValue :: BindMode -> Bool
isBMValue mode
 = case mode of
        BMValue{} -> True
        BMSpec    -> False
-- | Build an unchecked rewrite rule from its binders, constraints,
--   left-hand side, optional hole and right-hand side.
--
--   The effect weakening is 'Nothing', and the closure weakening and free
--   variables are empty. 'checkRewriteRule' computes those fields.
mkRewriteRule
        :: Ord n
        => [(BindMode,Bind n)]          -- ^ Rule binders.
        -> [Type n]                     -- ^ Constraints.
        -> Exp a n                      -- ^ Left-hand side.
        -> Maybe (Exp a n)              -- ^ Optional hole argument.
        -> Exp a n                      -- ^ Right-hand side.
        -> RewriteRule a n
mkRewriteRule binds constraints lhs hole rhs
 = RewriteRule
        { ruleBinds       = binds
        , ruleConstraints = constraints
        , ruleLeft        = lhs
        , ruleLeftHole    = hole
        , ruleRight       = rhs
        , ruleWeakEff     = Nothing
        , ruleWeakClo     = []
        , ruleFreeVars    = [] }
-- | Check a rewrite rule and annotate it with type information.
--
--   The checks run in sequence in the 'Either' monad, so the first failing
--   check determines the reported error.
--
--   On success the returned rule has:
--   * its binders spread, with value-binder use counts filled in;
--   * its constraints spread;
--   * both sides and the hole annotated;
--   * the effect weakening, closure weakening and free-variable fields
--     computed.
--
--   The weakening and free-variable fields of the input rule are ignored.
checkRewriteRule
        :: (Show a, Ord n, Show n, Pretty n)
        => C.Config n                   -- ^ Type checker configuration.
        -> T.Env n                      -- ^ Kind environment.
        -> T.Env n                      -- ^ Type environment.
        -> RewriteRule a n              -- ^ Rule to check.
        -> Either (Error a n)
                (RewriteRule (C.AnTEC a n) n)
checkRewriteRule config kenv tenv
        (RewriteRule bs cs lhs hole rhs _ _ _)
 = do
        -- Add the rule's binders to the kind / type environments.
        let (kenv', tenv', bs') = extendBinds bs kenv tenv
        let csSpread = map (S.spreadT kenv') cs

        -- Every constraint must kind-check and be a witness type.
        mapM_ (checkConstraint config kenv') csSpread

        -- Annotate the left-hand side on its own; only the annotated
        -- expression is kept.
        (lhs', _, _)
                <- checkExp config kenv' tenv' Lhs lhs

        -- Annotate the hole, if there is one.
        hole' <- case hole of
                Just h
                 -> do (h',_,_) <- checkExp config kenv' tenv' Lhs h
                       return $ Just h'
                Nothing -> return Nothing

        -- The full left-hand side is the left expression applied to the
        -- hole (when present); its type and effect are compared with the
        -- right-hand side's.
        let a = annotOfExp lhs
        let lhs_full = maybe lhs (XApp a lhs) hole
        (lhs_full', tLeft, effLeft)
                <- checkExp config kenv' tenv' Lhs lhs_full
        (rhs', tRight, effRight)
                <- checkExp config kenv' tenv' Rhs rhs

        -- Error reported on type/effect mismatch; both closure slots are
        -- filled with the bottom closure.
        let err = ErrorTypeConflict
                        (tLeft, effLeft, tBot kClosure)
                        (tRight, effRight, tBot kClosure)

        -- Both sides must have equivalent types.
        checkEquiv tLeft tRight err

        -- Nothing if the effects are equivalent, the left effect if it
        -- subsumes the right one, otherwise fail with 'err'.
        effWeak <- makeEffectWeakening T.kEffect effLeft effRight err

        -- Variables the left-hand side mentions but the right drops.
        cloWeak <- makeClosureWeakening config kenv' tenv' lhs_full' rhs'

        -- Every binder must be mentioned on the left, no binder may be
        -- anonymous, and the left must be a first-order pattern.
        checkUnmentionedBinders bs' lhs_full'
        checkAnonymousBinders bs'
        checkValidPattern lhs_full

        -- Record how often each value binder is used on the right.
        bs'' <- countBinderUsage bs' rhs

        -- Free variables of both sides, minus the rule's own binders.
        -- NOTE(review): lhs_full' is the spread, annotated left side but
        -- rhs is the raw right side. If spreading rewrites a Bound (e.g. into
        -- a primitive form), the same variable could appear in two forms
        -- here. Confirm whether rhs' was intended.
        let binds = Set.fromList
                        $ Maybe.catMaybes
                        $ map (T.takeSubstBoundOfBind . snd) bs
        let freeVars = Set.toList
                        $ (C.freeX T.empty lhs_full'
                                `Set.union` C.freeX T.empty rhs)
                        `Set.difference` binds

        return $ RewriteRule
                bs'' csSpread
                lhs' hole' rhs'
                effWeak cloWeak
                freeVars
-- | Extend the kind and type environments with a rule's binders.
--
--   Each binder is spread against the environments already extended with
--   all binders before it, so later binders may refer to earlier ones.
--   'BMSpec' binders go into the kind environment and 'BMValue' binders
--   into the type environment.
--
--   Returns the final environments and the spread binders, in their
--   original order.
extendBinds
        :: Ord n
        => [(BindMode, Bind n)]         -- ^ Rule binders, in order.
        -> KindEnv n                    -- ^ Starting kind environment.
        -> TypeEnv n                    -- ^ Starting type environment.
        -> (KindEnv n, TypeEnv n, [(BindMode, Bind n)])
extendBinds binds kenv tenv
 = go binds kenv tenv
 where
        -- The spread binders are consed onto the result on the way back
        -- out, rather than appended to an accumulator with (++), so the
        -- pass is linear in the number of binders.
        go [] k t
         = (k, t, [])

        go ((bm, b) : rest) k t
         = let
                b'       = S.spreadX k t b
                (k', t') = case bm of
                        BMSpec    -> (T.extend b' k, t)
                        BMValue _ -> (k, T.extend b' t)
                (kOut, tOut, rest') = go rest k' t'
           in   (kOut, tOut, (bm, b') : rest')
-- | Spread an expression against the environments, then type-check it in
--   reconstruction mode with no demand.
--
--   On success, returns the annotated expression with its type and effect.
--   A checker failure is wrapped as 'ErrorTypeCheck', tagged with the rule
--   side and the spread expression.
checkExp
        :: (Show a, Ord n, Show n, Pretty n)
        => C.Config n
        -> KindEnv n
        -> TypeEnv n
        -> Side
        -> Exp a n
        -> Either (Error a n)
                (Exp (C.AnTEC a n) n, Type n, Effect n)
checkExp defs kenv tenv side xx
 = either (Left . ErrorTypeCheck side xxSpread) Right checked
 where
        xxSpread = S.spreadX kenv tenv xx
        checked  = fst $ C.checkExp defs kenv tenv C.Recon C.DemandNone xxSpread
-- | Check that a rule constraint is a valid witness type.
--
--   Returns the constraint's kind when it kind-checks and is a witness
--   type. Otherwise it fails with 'ErrorBadConstraint' (the kind-checker's
--   own error is discarded).
checkConstraint
        :: (Ord n, Show n, Pretty n)
        => C.Config n
        -> KindEnv n
        -> Type n
        -> Either (Error a n) (Kind n)
checkConstraint config kenv tt
 = case T.checkSpec config kenv tt of
        Right (_, k)
         | T.isWitnessType tt   -> Right k
        _                       -> Left (ErrorBadConstraint tt)
-- | Succeed when the two types are equivalent ('T.equivT'); otherwise fail
--   with the supplied error.
checkEquiv
        :: Ord n
        => Type n
        -> Type n
        -> Error a n
        -> Either (Error a n) ()
checkEquiv tLeft tRight err
 = if T.equivT tLeft tRight
        then Right ()
        else Left err
-- | Decide whether the right-hand effect needs weakening to match the left.
--
--   * 'Nothing' when the two effects are equivalent.
--   * The left effect when it subsumes the right one at the given kind.
--   * Otherwise the supplied error.
makeEffectWeakening
        :: (Ord n, Show n)
        => Kind n
        -> Effect n
        -> Effect n
        -> Error a n
        -> Either (Error a n) (Maybe (Type n))
makeEffectWeakening k effLeft effRight onError
 | sameEffect   = Right Nothing
 | leftCovers   = Right (Just effLeft)
 | otherwise    = Left onError
 where
        sameEffect = T.equivT effLeft effRight
        leftCovers = T.subsumesT k effLeft effRight
-- | Compute the closure weakening for a rule.
--
--   First the effect type arguments on both sides are replaced with the
--   bottom effect. Then the data, witness and spec variables in the
--   support of the left-hand side but not the right are returned, in that
--   order.
--
--   Each variable is wrapped as an expression carrying the left-hand
--   side's annotation. This function never fails.
makeClosureWeakening
        :: (Ord n, Pretty n, Show n)
        => C.Config n
        -> T.Env n
        -> T.Env n
        -> Exp (C.AnTEC a n) n
        -> Exp (C.AnTEC a n) n
        -> Either (Error a n)
                [Exp (C.AnTEC a n) n]
makeClosureWeakening config kenv tenv lhs rhs
 = Right
        $  map (XVar a)                 (droppedBy supportDaVar)
        ++ map (XWitness a . WVar a)    (droppedBy supportWiVar)
        ++ map (XType a . TVar)         (droppedBy supportSpVar)
 where
        a = annotOfExp lhs

        -- Support of an expression once its effect arguments are erased.
        supportOf x = support Env.empty Env.empty (removeEffects config kenv tenv x)

        supLeft  = supportOf lhs
        supRight = supportOf rhs

        -- Variables of the selected class present on the left only.
        droppedBy select
         = Set.toList (select supLeft `Set.difference` select supRight)
-- | Replace every type argument whose kind checks as an effect kind with
--   the bottom effect, traversing the expression bottom-up.
--
--   All other sub-expressions are left unchanged.
removeEffects
        :: (Ord n, Pretty n, Show n)
        => C.Config n
        -> T.Env n
        -> T.Env n
        -> Exp a n
        -> Exp a n
removeEffects config kenv0 tenv0 xx0
 = transformUpX eraseEffectArg kenv0 tenv0 xx0
 where
        eraseEffectArg kenv _ (XType a et)
         | Right (_, k) <- T.checkSpec config kenv et
         , T.isEffectKind k
         = XType a (T.tBot T.kEffect)
        eraseEffectArg _ _ x
         = x
-- | Check that every rule binder occurs free in the given expression.
--
--   Only binders with a substitutable bound ('T.takeSubstBoundOfBind')
--   count. A binder may occur as a value-level ('C.freeX') or type-level
--   ('C.freeT') variable. Fails with 'ErrorVarUnmentioned' otherwise.
checkUnmentionedBinders
        :: (Ord n, Show n)
        => [(BindMode, Bind n)]
        -> Exp (C.AnTEC a n) n
        -> Either (Error a n) ()
checkUnmentionedBinders bs expr
 | bound `Set.isSubsetOf` mentioned     = Right ()
 | otherwise                            = Left ErrorVarUnmentioned
 where
        mentioned = C.freeX T.empty expr `Set.union` C.freeT T.empty expr
        bound     = Set.fromList
                        [ u | (_, b) <- bs
                            , Just u <- [T.takeSubstBoundOfBind b] ]
-- | Reject rules with anonymous binders, reporting the first one found
--   with 'ErrorAnonymousBinder'.
checkAnonymousBinders
        :: [(BindMode, Bind n)]
        -> Either (Error a n) ()
checkAnonymousBinders bs
 = case filter T.isBAnon (map snd bs) of
        b : _   -> Left (ErrorAnonymousBinder b)
        []      -> Right ()
-- | Check that an expression is a first-order pattern.
--
--   * Variables, constructors and witnesses are accepted.
--   * Applications and casts are checked recursively.
--   * Type arguments are checked structurally.
--   * Type or value lambdas, lets and cases are rejected with
--     'ErrorNotFirstOrder'.
--   * A quantified type argument is also rejected with
--     'ErrorNotFirstOrder', reported as the enclosing type argument
--     expression.
checkValidPattern :: Exp a n -> Either (Error a n) ()
checkValidPattern expr
 = checkX expr
 where
        checkX xx
         = case xx of
                XVar{}          -> Right ()
                XCon{}          -> Right ()
                XWitness{}      -> Right ()
                XApp _ fun arg  -> checkX fun >> checkX arg
                XCast _ _ body  -> checkX body
                XType a t       -> checkT a t
                XLAM{}          -> Left $ ErrorNotFirstOrder xx
                XLam{}          -> Left $ ErrorNotFirstOrder xx
                XLet{}          -> Left $ ErrorNotFirstOrder xx
                XCase{}         -> Left $ ErrorNotFirstOrder xx

        checkT a tt
         = case tt of
                TVar{}          -> Right ()
                TCon{}          -> Right ()
                TSum{}          -> Right ()
                TApp l r        -> checkT a l >> checkT a r
                TForall{}       -> Left $ ErrorNotFirstOrder (XType a tt)
-- | Record how many times each named value binder is used in an
--   expression.
--
--   The usage analysis is run on the expression. Each 'BMValue' binder
--   with a 'BName' gets its count replaced by the length of its usage list
--   (0 if it is absent). All other binders are returned unchanged. This
--   function never fails.
countBinderUsage
        :: Ord n
        => [(BindMode, Bind n)]
        -> Exp a n
        -> Either (Error a n) [(BindMode, Bind n)]
countBinderUsage bs x
 = Right (map withUsage bs)
 where
        U.UsedMap um = fst $ annotOfExp $ U.usageX x

        withUsage (BMValue _, BName n t)
         = (BMValue (maybe 0 length (Map.lookup n um)), BName n t)
        withUsage other
         = other
-- | Map a function over the annotations of every expression in a rule.
--
--   This covers both sides, the optional hole and the closure-weakening
--   expressions. The remaining fields are copied unchanged.
instance Reannotate RewriteRule where
 reannotate f (RewriteRule binds constrs lhs mHole rhs weakEff weakClo frees)
        = RewriteRule
        { ruleBinds       = binds
        , ruleConstraints = constrs
        , ruleLeft        = reannotate f lhs
        , ruleLeftHole    = fmap (reannotate f) mHole
        , ruleRight       = reannotate f rhs
        , ruleWeakEff     = weakEff
        , ruleWeakClo     = map (reannotate f) weakClo
        , ruleFreeVars    = frees }