{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE QuasiQuotes #-}
{-# LANGUAGE ViewPatterns #-}

-- |
-- Module      : Jikka.Core.Convert.EqualitySolving
-- Description : equality solving. / 等式を解きます
-- Copyright   : (c) Kimiyuki Onaka, 2021
-- License     : Apache License 2.0
-- Maintainer  : hotman78@gmail.com
-- Stability   : experimental
-- Portability : portable
--
-- \[
--     \newcommand\int{\mathbf{int}}
--     \newcommand\bool{\mathbf{bool}}
--     \newcommand\list{\mathbf{list}}
-- \]
module Jikka.Core.Convert.EqualitySolving
  ( run,
    rule,

    -- * internal rules
    moveLiteralToRight,
    convertGreaterToLess,
    reduceReflexivity,
    makeRightZero,
    reduceIntInjective,
    reduceNot,
    reduceListCtor,
    reduceListInjective,
  )
where

import Jikka.Common.Alpha
import Jikka.Common.Error
import Jikka.Core.Language.BuiltinPatterns
import Jikka.Core.Language.Expr
import Jikka.Core.Language.Lint
import Jikka.Core.Language.QuasiRules
import Jikka.Core.Language.RewriteRules
import Jikka.Core.Language.Util

-- | `moveLiteralToRight` moves literals to the right-hand side of @(==)@ or
-- @(/=)@, using symmetry of both operators.
--
-- A comparison is flipped only when its left operand is a literal and its
-- right operand is not, e.g. @1 == x@ becomes @x == 1@ and @1 /= x@ becomes
-- @x /= 1@. Comparisons between two literals (or two non-literals) are left
-- unchanged, so the rules never undo each other.
moveLiteralToRight :: (MonadAlpha m, MonadError Error m) => RewriteRule m
moveLiteralToRight =
  mconcat
    [ simpleRewriteRule "equal/symmetricity/literal" $ \case
        Equal' t x y | isLiteral x && not (isLiteral y) -> Just $ Equal' t y x
        _ -> Nothing,
      -- This rule must match and rebuild @(/=)@; matching @(==)@ here would
      -- leave every @literal /= e@ un-normalized.
      simpleRewriteRule "notequal/symmetricity/literal" $ \case
        NotEqual' t x y | isLiteral x && not (isLiteral y) -> Just $ NotEqual' t y x
        _ -> Nothing
    ]

-- | `convertGreaterToLess` eliminates @(>)@ and @(>=)@ by swapping their
-- operands:
--
-- * @x > y@ becomes @y < x@
-- * @x >= y@ becomes @y <= x@
convertGreaterToLess :: (MonadAlpha m, MonadError Error m) => RewriteRule m
convertGreaterToLess = mconcat flipRules
  where
    flipRules =
      [ [r| "greaterthan->lessthan" forall x y. x > y = y < x |],
        [r| "greaterequal->lessequal" forall x y. x >= y = y <= x |]
      ]

-- | `reduceReflexivity` folds a comparison of an expression with itself to a
-- boolean constant. It uses reflexivity of @(<=)@ and @(==)@ and
-- irreflexivity of @(<)@ and @(/=)@:
--
-- * @x < x@ becomes @false@
-- * @x <= x@ becomes @true@
-- * @x == x@ becomes @true@
-- * @x /= x@ becomes @false@
reduceReflexivity :: (MonadAlpha m, MonadError Error m) => RewriteRule m
reduceReflexivity =
  mconcat
    [ -- Each rule must match its own operator. If the lessthan or notequal
      -- rule matched @x == x@, it would rewrite a true equality to @false@.
      [r| "lessthan/reflexivity" forall x. x < x = false |],
      [r| "lessequal/reflexivity" forall x. x <= x = true |],
      [r| "equal/reflexivity" forall x. x == x = true |],
      [r| "notequal/reflexivity" forall x. x /= x = false |]
    ]

-- | `makeRightZero` normalizes integer comparisons so that their right-hand
-- side is the literal @0@, by subtracting the right-hand side from both sides:
--
-- * @x < y@ becomes @x - y < 0@
-- * @x <= y@ becomes @x - y <= 0@
-- * @x == y@ becomes @x - y == 0@
-- * @x /= y@ becomes @x - y /= 0@
--
-- Only comparisons at type @int@ are touched. A comparison whose right-hand
-- side is already the literal @0@ is left as is, so no rule re-fires on its
-- own output.
makeRightZero :: (MonadAlpha m, MonadError Error m) => RewriteRule m
makeRightZero =
  mconcat
    [ simpleRewriteRule "lessthan/right-zero" $ \case
        LessThan' IntTy lhs rhs
          | rhs == zeroExpr -> Nothing
          | otherwise -> Just $ LessThan' IntTy (Minus' lhs rhs) zeroExpr
        _ -> Nothing,
      simpleRewriteRule "lessequal/right-zero" $ \case
        LessEqual' IntTy lhs rhs
          | rhs == zeroExpr -> Nothing
          | otherwise -> Just $ LessEqual' IntTy (Minus' lhs rhs) zeroExpr
        _ -> Nothing,
      simpleRewriteRule "equal/right-zero" $ \case
        Equal' IntTy lhs rhs
          | rhs == zeroExpr -> Nothing
          | otherwise -> Just $ Equal' IntTy (Minus' lhs rhs) zeroExpr
        _ -> Nothing,
      simpleRewriteRule "notequal/right-zero" $ \case
        NotEqual' IntTy lhs rhs
          | rhs == zeroExpr -> Nothing
          | otherwise -> Just $ NotEqual' IntTy (Minus' lhs rhs) zeroExpr
        _ -> Nothing
    ]
  where
    -- The literal every comparison is normalized against.
    zeroExpr = LitInt' 0

-- | `reduceIntInjective` removes injective functions from integer equalities
-- whose right-hand side is @0@ (the shape produced by `makeRightZero`):
--
-- * @- x == 0@ becomes @x == 0@. Negation is injective and maps @0@ to @0@.
--
-- No rule here cancels @fact@ (e.g. @fact x - fact y == 0@ to @x == y@).
-- Factorial is not injective, because @fact 0 == fact 1 == 1@. Such a
-- rewrite would turn the true @fact 0 == fact 1@ into the false @0 == 1@.
reduceIntInjective :: (MonadAlpha m, MonadError Error m) => RewriteRule m
reduceIntInjective =
  mconcat
    [ [r| "equal/negate" forall x. - x == 0 = x == 0 |]
    ]

-- | `reduceNot` pushes a boolean @not@ out of either side of @(==)@ or
-- @(/=)@ by flipping the comparison:
--
-- * @not x == y@ and @x == not y@ become @x /= y@
-- * @not x /= y@ and @x /= not y@ become @x == y@
reduceNot :: (MonadAlpha m, MonadError Error m) => RewriteRule m
reduceNot = mconcat notRules
  where
    notRules =
      [ [r| "equal/not" forall x y. not x == y = x /= y |],
        [r| "equal/not'" forall x y. x == not y = x /= y |],
        [r| "notequal/not" forall x y. not x /= y = x == y |],
        [r| "notequal/not'" forall x y. x /= not y = x == y |]
      ]

-- | `reduceListCtor` decides equalities between lists built from the
-- constructors @nil@ and @cons@:
--
-- * @nil == nil@ becomes @true@
-- * @cons x xs == nil@ and @nil == cons x xs@ become @false@
-- * @cons x xs == cons y ys@ becomes @x == y && xs == ys@
reduceListCtor :: (MonadAlpha m, MonadError Error m) => RewriteRule m
reduceListCtor = mconcat ctorRules
  where
    ctorRules =
      [ [r| "equal/nil/nil" forall x xs. nil == nil = true |],
        [r| "equal/cons/nil" forall x xs. cons x xs == nil = false |],
        [r| "equal/nil/cons" forall x xs. nil == cons x xs = false |],
        [r| "equal/cons/cons" forall x xs y ys. cons x xs == cons y ys = x == y && xs == ys |]
      ]

-- | `reduceListInjective` removes list-producing functions from equalities
-- when they are applied on both sides:
--
-- * @range n1 == range n2@ becomes @n1 == n2@
--
-- NOTE(review): this relies on @range@ being injective. If @range n@ yields
-- the empty list for every negative @n@ (as Python's @range@ does), then the
-- rule is unsound for distinct negative arguments. Confirm how the core
-- language defines @range@ on negative inputs.
reduceListInjective :: (MonadAlpha m, MonadError Error m) => RewriteRule m
reduceListInjective = mconcat injectiveRules
  where
    injectiveRules =
      [ [r| "equal/range/range" forall n1 n2. range n1 == range n2 = n1 == n2 |]
      ]

-- | `rule` is the combination of every equality-solving rule group in this
-- module. The groups are passed to `mconcat` in this order: literal
-- movement, @>@/@>=@ elimination, reflexivity, right-zero normalization,
-- integer injectivity, @not@ elimination, list constructors, list
-- injectivity.
rule :: (MonadAlpha m, MonadError Error m) => RewriteRule m
rule = mconcat ruleGroups
  where
    ruleGroups =
      [ moveLiteralToRight,
        convertGreaterToLess,
        reduceReflexivity,
        makeRightZero,
        reduceIntInjective,
        reduceNot,
        reduceListCtor,
        reduceListInjective
      ]

-- | Rewrites a whole program with `rule` via `applyRewriteRuleProgram'`.
runProgram :: (MonadAlpha m, MonadError Error m) => Program -> m Program
runProgram prog = applyRewriteRuleProgram' rule prog

-- | `run` is the entry point of this pass. It applies the equality-solving
-- rules to a program.
--
-- The input program is linted inside `precondition` and the rewritten
-- program inside `postcondition`. The whole pass runs under `wrapError'`,
-- labelled with this module's name.
run :: (MonadAlpha m, MonadError Error m) => Program -> m Program
run prog = wrapError' "Jikka.Core.Convert.EqualitySolving" $ do
  precondition (lint prog)
  rewritten <- runProgram prog
  postcondition (lint rewritten)
  pure rewritten