{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE LambdaCase #-}
module Jikka.Core.Convert.SortAbs
( run,
rule,
)
where
import Control.Monad.Trans.Maybe
import Jikka.Common.Alpha
import Jikka.Common.Error
import qualified Jikka.Core.Convert.Alpha as Alpha
import Jikka.Core.Language.ArithmeticExpr
import Jikka.Core.Language.Beta
import Jikka.Core.Language.BuiltinPatterns
import Jikka.Core.Language.Expr
import Jikka.Core.Language.Lint
import Jikka.Core.Language.QuasiRules
import Jikka.Core.Language.RewriteRules
import Jikka.Core.Language.Util
-- | @replaceAbsDelta x y z e@ walks @e@ with 'mapSubExpr' and replaces every
-- sub-expression of the form @Abs' d@ by @Var z@ whenever @d@ is
-- arithmetically equal to @x - y@ or to @y - x@.
--
-- Equality is decided by checking that the difference of the two expressions
-- parses (via 'parseArithmeticExpr') to something 'isZeroArithmeticExpr'
-- accepts, so the match is not purely syntactic.
-- All other sub-expressions are returned unchanged.
replaceAbsDelta :: VarName -> VarName -> VarName -> Expr -> Expr
replaceAbsDelta x y z e = mapSubExpr go [] e
  where
    -- The first argument supplied by 'mapSubExpr' is not needed here.
    go _ = \case
      Abs' d
        | d `sameAs` Minus' (Var x) (Var y) || d `sameAs` Minus' (Var y) (Var x) ->
          Var z
      other -> other

    -- Two expressions are "the same" when their difference is arithmetically zero.
    sameAs a b = isZeroArithmeticExpr (parseArithmeticExpr (Minus' a b))
-- | @swapTwoVars x y e@ exchanges the roles of the variables @x@ and @y@ in @e@.
--
-- The swap goes through two intermediate names obtained from 'genVarName'
-- (presumably fresh; confirm against "Jikka.Common.Alpha") so that the
-- substitution of @x@ by @y@ is not undone by the later substitution of @y@
-- by @x@:
--
-- 1. @x@ becomes @x'@, and @y@ becomes @y'@;
-- 2. @x'@ becomes @y@, and @y'@ becomes @x@.
swapTwoVars :: MonadAlpha m => VarName -> VarName -> Expr -> m Expr
swapTwoVars x y e = do
  x' <- genVarName x
  y' <- genVarName y
  -- Step 1: move both variables out of the way.
  e1 <- substitute x (Var x') e
  e2 <- substitute y (Var y') e1
  -- Step 2: bring them back under each other's name.
  e3 <- substitute x' (Var y) e2
  substitute y' (Var x) e3
-- | @isSymmetric x y f@ reports whether @f@ is unchanged, as an arithmetic
-- expression, when the variables @x@ and @y@ are exchanged.
--
-- The comparison is made on the results of 'parseArithmeticExpr' for the
-- swapped and the original expression, not on the raw syntax trees.
isSymmetric :: MonadAlpha m => VarName -> VarName -> Expr -> m Bool
isSymmetric x y f = do
  swapped <- swapTwoVars x y f
  return (parseArithmeticExpr swapped == parseArithmeticExpr f)
-- | Rewrite rule registered under the name @sum\/sum\/abs\/symmetric@.
--
-- It fires on a double sum in which both maps range over the same list:
--
-- > Sum' (Map' IntTy _ (Lam x _ (Sum' (Map' _ _ (Lam y _ f) xs))) xs)
--
-- and only when both of the following hold:
--
-- * 'replaceAbsDelta' actually changes @f@, i.e. @f@ contains an
--   @abs (x - y)@ (or @abs (y - x)@) to eliminate;
-- * the body with that absolute value abstracted into a variable is still
--   symmetric in @x@ and @y@ according to 'isSymmetric'.
--
-- The result binds @ys@ to @Sorted' IntTy xs@ and, for every index @i@ of
-- @ys@ with @x = ys[i]@, adds three pieces in which the abstracted absolute
-- value is replaced by a sign-resolved expression:
--
-- * a sum over @j@ in @Range1' i@ with @x - y@,
-- * the single term @j = i@ with @0@,
-- * a sum over @j@ in @Range2' (i + 1) (len ys)@ with @y - x@,
--
-- where @y = ys[j]@ in each piece.
-- NOTE(review): the choice of sign per piece presumably relies on @Sorted'@
-- being ascending and on @Range1'@ \/ @Range2'@ being half-open ranges;
-- confirm against their definitions.
--
-- The rebuilt expression is passed through 'Alpha.runExpr' with the type
-- environment of the rewrite, since @x@, @y@ and @j@ are bound more than once
-- in it.  Any expression that does not match yields 'Nothing'.
rule :: (MonadAlpha m, MonadError Error m) => RewriteRule m
rule = makeRewriteRule "sum/sum/abs/symmetric" $ \env -> \case
  Sum' (Map' IntTy _ (Lam x _ (Sum' (Map' _ _ (Lam y _ f) xs'))) xs)
    | xs' == xs -> runMaybeT $ do
      -- Abstract abs(x - y) into a placeholder variable.
      delta <- lift genVarName'
      let f' = replaceAbsDelta x y delta f
      -- Give up when there was nothing to abstract ...
      guard $ f' /= f
      -- ... or when the remaining body distinguishes x from y.
      guard =<< lift (isSymmetric x y f')
      ys <- lift $ genVarName'' xs
      i <- lift genVarName'
      j <- lift genVarName'
      -- Instantiate the placeholder for each of the three index regions.
      lt <- lift $ substitute delta (Minus' (Var x) (Var y)) f'
      eq <- lift $ substitute delta (LitInt' 0) f'
      gt <- lift $ substitute delta (Minus' (Var y) (Var x)) f'
      let -- Binds y to the j-th element of the sorted list around a body.
          ctx = Let y IntTy (At' IntTy (Var ys) (Var j))
          lt' = Sum' (Map' IntTy IntTy (Lam j IntTy (ctx lt)) (Range1' (Var i)))
          eq' = Let j IntTy (Var i) (ctx eq)
          gt' =
            Sum'
              ( Map'
                  IntTy
                  IntTy
                  (Lam j IntTy (ctx gt))
                  (Range2' (Plus' (Var i) (LitInt' 1)) (Len' IntTy (Var ys)))
              )
          -- Contribution of a single index i, with x bound to ys[i].
          perIndex =
            Lam i IntTy $
              Let x IntTy (At' IntTy (Var ys) (Var i)) $
                Plus' (Plus' lt' eq') gt'
          e =
            Let ys (ListTy IntTy) (Sorted' IntTy xs) $
              Sum' (Map' IntTy IntTy perIndex (Range1' (Len' IntTy (Var ys))))
      lift $ Alpha.runExpr (typeEnv env) e
  _ -> return Nothing
-- | Applies 'rule' to a whole program by handing it to
-- 'applyRewriteRuleProgram''.
runProgram :: (MonadAlpha m, MonadError Error m) => Program -> m Program
runProgram = applyRewriteRuleProgram' rule
-- | Entry point of this pass.
--
-- The program is linted as a precondition, rewritten with 'runProgram', and
-- the result is linted again as a postcondition before being returned.
-- The whole computation runs inside 'wrapError'' tagged with this module's
-- name.
run :: (MonadAlpha m, MonadError Error m) => Program -> m Program
run prog = wrapError' "Jikka.Core.Convert.SortAbs" $ do
  precondition $ lint prog
  rewritten <- runProgram prog
  postcondition $ lint rewritten
  return rewritten