{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE LambdaCase #-}

-- |
-- Module      : Jikka.Core.Convert.SortAbs
-- Description : remove abs with sorting. / sort によって abs を除去します。
-- Copyright   : (c) Kimiyuki Onaka, 2021
-- License     : Apache License 2.0
-- Maintainer  : kimiyuki95@gmail.com
-- Stability   : experimental
-- Portability : portable
--
-- \[
--     \newcommand\int{\mathbf{int}}
--     \newcommand\bool{\mathbf{bool}}
--     \newcommand\list{\mathbf{list}}
-- \]
module Jikka.Core.Convert.SortAbs
  ( run,

    -- * internal rules
    rule,
  )
where

import Control.Monad.Trans.Maybe
import Jikka.Common.Alpha
import Jikka.Common.Error
import qualified Jikka.Core.Convert.Alpha as Alpha
import Jikka.Core.Language.ArithmeticExpr
import Jikka.Core.Language.Beta
import Jikka.Core.Language.BuiltinPatterns
import Jikka.Core.Language.Expr
import Jikka.Core.Language.Lint
import Jikka.Core.Language.QuasiRules
import Jikka.Core.Language.RewriteRules
import Jikka.Core.Language.Util

-- | @replaceAbsDelta x y z e@ replaces \(\lvert x - y \rvert\) (or, equivalently,
-- \(\lvert y - x \rvert\)) in \(e\) with the variable \(z\).
--
-- The argument of an @abs@ is matched up to arithmetic normalization: it is
-- replaced when 'parseArithmeticExpr' shows that it differs from @x - y@ or
-- @y - x@ by zero (checked with 'isZeroArithmeticExpr').
--
-- Occurrences whose binder environment (as supplied by 'mapSubExpr', starting
-- from @[]@) contains @x@ or @y@ are left untouched, because under such a
-- binder the names no longer refer to the outer variables.
-- NOTE(review): this assumes 'mapSubExpr' extends the environment with the
-- binders it descends under; if it does not, the guard is simply always true.
replaceAbsDelta :: VarName -> VarName -> VarName -> Expr -> Expr
replaceAbsDelta x y z = mapSubExpr go []
  where
    go env = \case
      Abs' e
        | not (isShadowed env) && (isDiff x y e || isDiff y x e) -> Var z
      e -> e
    -- True when an enclosing binder inside the expression rebinds x or y.
    isShadowed env = any (`elem` map fst env) [x, y]
    -- @isDiff a b e@ holds when @e - (a - b)@ normalizes to zero.
    isDiff a b e = isZeroArithmeticExpr (parseArithmeticExpr (Minus' e (Minus' (Var a) (Var b))))

-- | @swapTwoVars x y e@ exchanges the variables @x@ and @y@ in @e@ using
-- 'substitute'.
--
-- Two fresh temporaries (generated from @x@ and @y@) are used so that the
-- first substitution is not undone by the second: @x@ and @y@ are renamed to
-- the temporaries first, and only then are the temporaries replaced by @y@
-- and @x@ respectively.
swapTwoVars :: MonadAlpha m => VarName -> VarName -> Expr -> m Expr
swapTwoVars x y e = do
  tmpX <- genVarName x
  tmpY <- genVarName y
  substitute x (Var tmpX) e
    >>= substitute y (Var tmpY)
    >>= substitute tmpX (Var y)
    >>= substitute tmpY (Var x)

-- | @isSymmetric x y f@ decides whether @f@ stays the same when @x@ and @y@
-- are swapped (via 'swapTwoVars').
--
-- Both sides are compared after 'parseArithmeticExpr', so only arithmetic
-- rearrangements are seen through.
--
-- TODO: accept more functions
isSymmetric :: MonadAlpha m => VarName -> VarName -> Expr -> m Bool
isSymmetric x y f = sameAsOriginal <$> swapTwoVars x y f
  where
    sameAsOriginal swapped = parseArithmeticExpr swapped == parseArithmeticExpr f

-- | Rewrites a double sum over one list, whose summand contains the absolute
-- difference of the two bound variables, into a sum over the sorted list.
--
-- It matches @sum (map (fun x -> sum (map (fun y -> f) xs)) xs)@, where both
-- maps run over the same @xs@ and the outer element type is @int@. It fires
-- only when
--
-- * @f@ contains \(\lvert x - y \rvert\) (detected by 'replaceAbsDelta'), and
-- * @f@ is symmetric in @x@ and @y@ (decided by 'isSymmetric').
--
-- With @ys = sorted xs@, @x = ys[i]@ and @y = ys[j]@, the absolute value is
-- replaced by \(x - y\) for \(j \lt i\), by \(0\) for \(j = i\), and by
-- \(y - x\) for \(j \gt i\). The inner sum is split into those three ranges,
-- and the result is alpha-converted in the rule's type environment.
rule :: (MonadAlpha m, MonadError Error m) => RewriteRule m
rule = makeRewriteRule "sum/sum/abs/symmetric" $ \env -> \case
  Sum' (Map' IntTy _ (Lam x _ (Sum' (Map' _ _ (Lam y _ body) inner))) outer)
    | inner == outer -> runMaybeT $ do
      -- A fresh variable stands in for |x - y| while the summand is inspected.
      delta <- lift genVarName'
      let body' = replaceAbsDelta x y delta body
      guard $ body' /= body -- the summand really contains |x - y|
      symmetric <- lift $ isSymmetric x y body'
      guard symmetric
      ysName <- lift $ genVarName'' outer
      i <- lift genVarName'
      j <- lift genVarName'
      -- One copy of the summand per sign of x - y in the sorted list.
      beforeBody <- lift $ substitute delta (Minus' (Var x) (Var y)) body'
      sameBody <- lift $ substitute delta (LitInt' 0) body'
      afterBody <- lift $ substitute delta (Minus' (Var y) (Var x)) body'
      let ys = Var ysName
          bindY = Let y IntTy (At' IntTy ys (Var j))
          sumOverJ range e = Sum' (Map' IntTy IntTy (Lam j IntTy (bindY e)) range)
          before = sumOverJ (Range1' (Var i)) beforeBody
          same = Let j IntTy (Var i) (bindY sameBody)
          after = sumOverJ (Range2' (Plus' (Var i) (LitInt' 1)) (Len' IntTy ys)) afterBody
          perI = Lam i IntTy (Let x IntTy (At' IntTy ys (Var i)) (Plus' (Plus' before same) after))
          rewritten = Let ysName (ListTy IntTy) (Sorted' IntTy outer) (Sum' (Map' IntTy IntTy perI (Range1' (Len' IntTy ys))))
      lift $ Alpha.runExpr (typeEnv env) rewritten
  _ -> return Nothing

-- | Applies 'rule' to the whole program with 'applyRewriteRuleProgram''.
runProgram :: (MonadAlpha m, MonadError Error m) => Program -> m Program
runProgram prog = applyRewriteRuleProgram' rule prog

-- | `run` rewrites \(\sum _ {a_i \in a} \sum _ {a_j \in a} f(a, a_i, a_j)\) to \(\mathbf{let}~ b = \mathrm{sort}(a) ~\mathbf{in}~ \sum _ i \sum _ j f'(a, b_i, b_j)\) when \(f\) contains \(\lvert a_i - a_j \rvert\) and \(f(a, a_i, a_j) = f(a, a_j, a_i)\) holds.
-- Here \(f'\) is \(f\) with \(\lvert b_i - b_j \rvert\) replaced by \(b_i - b_j\) for \(j \lt i\), by \(0\) for \(j = i\), and by \(b_j - b_i\) for \(j \gt i\).
--
-- The program is linted before and after the rewrite, and errors are wrapped with this module's name.
--
-- == Example
--
-- Before:
--
-- > sum (map (fun (a_i: int) ->
-- >     sum (map (fun (a_j: int) ->
-- >         abs (a_i - a_j)
-- >     ) a)
-- > ) a)
--
-- After:
--
-- > let b = sort a
-- > in sum (map (fun (i: int) ->
-- >     (sum (map (fun (b_j: int) ->
-- >         b_i - b_j
-- >     ) b[:i])
-- >     + 0
-- >     + sum (map (fun (b_j: int) ->
-- >         b_j - b_i
-- >     ) b[i + 1:]))
-- > ) (range (length b)))
run :: (MonadAlpha m, MonadError Error m) => Program -> m Program
run prog = wrapError' "Jikka.Core.Convert.SortAbs" $ do
  precondition $ lint prog
  prog' <- runProgram prog
  postcondition $ lint prog'
  return prog'