{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE LambdaCase #-}

-- |
-- Module      : Jikka.Core.Convert.ConvexHullTrick
-- Description : uses convex hull trick. / convex hull trick を使います。
-- Copyright   : (c) Kimiyuki Onaka, 2021
-- License     : Apache License 2.0
-- Maintainer  : kimiyuki95@gmail.com
-- Stability   : experimental
-- Portability : portable
--
-- \[
--     \newcommand\int{\mathbf{int}}
--     \newcommand\bool{\mathbf{bool}}
--     \newcommand\list{\mathbf{list}}
-- \]
module Jikka.Core.Convert.ConvexHullTrick
  ( run,

    -- * internal rules
    rule,
    parseLinearFunctionBody,
    parseLinearFunctionBody',
  )
where

import Control.Monad.Trans.Maybe
import Jikka.Common.Alpha
import Jikka.Common.Error
import qualified Jikka.Core.Convert.Alpha as Alpha
import Jikka.Core.Language.ArithmeticExpr
import Jikka.Core.Language.Beta
import Jikka.Core.Language.BuiltinPatterns
import Jikka.Core.Language.Expr
import Jikka.Core.Language.FreeVars
import Jikka.Core.Language.Lint
import Jikka.Core.Language.RewriteRules
import Jikka.Core.Language.Util

-- | 'plusPair' adds two products, each given as a pair @(a, c)@ standing for @a * c@.
--
-- The sum @a1 * c1 + a2 * c2@ can be written as a single product only in these cases:
--
-- * one of the two products is zero (either of its factors is zero), in which case the other pair is returned unchanged;
-- * @c1@ and @c2@ become equal once 'splitConstantFactorArithmeticExpr' has split off their integer factors @k1@ and @k2@,
--   in which case the result is @(k1 * a1 + k2 * a2, c)@ with @c@ being the common remaining part.
--
-- Otherwise 'Nothing' is returned.
-- This operation behaves like a commutative addition because only one kind of @c@ is allowed.
plusPair :: (ArithmeticExpr, ArithmeticExpr) -> (ArithmeticExpr, ArithmeticExpr) -> Maybe (ArithmeticExpr, ArithmeticExpr)
plusPair (a1, c1) (a2, c2)
  -- The right summand is zero, so the left one is the sum.
  | isZeroArithmeticExpr a2 || isZeroArithmeticExpr c2 = Just (a1, c1)
  -- The left summand is zero, so the right one is the sum.
  | isZeroArithmeticExpr a1 || isZeroArithmeticExpr c1 = Just (a2, c2)
  -- Both summands share the same @c@ up to the split-off integer factors.
  | c1' == c2' = Just (plusArithmeticExpr a1' a2', c1')
  | otherwise = Nothing
  where
    (k1, c1') = splitConstantFactorArithmeticExpr c1
    (k2, c2') = splitConstantFactorArithmeticExpr c2
    -- Move the integer factors of @c1@ and @c2@ to the @a@ side before adding.
    a1' = multArithmeticExpr (integerArithmeticExpr k1) a1
    a2' = multArithmeticExpr (integerArithmeticExpr k2) a2

-- | 'sumPairs' adds up a list of products @(a, c)@ with 'plusPair', folding from the right.
--
-- The fold starts from the pair @(1, 0)@, i.e. the zero product @1 * 0@.
-- The result is 'Nothing' as soon as one application of 'plusPair' fails.
sumPairs :: [(ArithmeticExpr, ArithmeticExpr)] -> Maybe (ArithmeticExpr, ArithmeticExpr)
sumPairs = foldr step (Just (integerArithmeticExpr 1, integerArithmeticExpr 0))
  where
    -- Add the current pair to the already accumulated sum, if there is one.
    step pair acc = acc >>= plusPair pair

-- | `parseLinearFunctionBody'` parses the body of a linear function which can be decomposed to convex hull trick.
-- @parseLinearFunctionBody' f i j e@ finds four expressions @a, c, b, d@ where @e = a(f[j], j) c(f[< i], i) + b(f[j], j) + d(f[< i], i)@.
--
-- The result tuple is ordered as @(a, c, b, d)@ (note that @c@ comes before @b@).
-- The integer factor which 'splitConstantFactorArithmeticExpr' splits off from @a@ is moved into @c@ before the result is formatted.
-- 'Nothing' is returned when @e@ cannot be decomposed in this form.
--
-- TODO: What is the relation between @j@ and @k@?
parseLinearFunctionBody' :: VarName -> VarName -> VarName -> Expr -> Maybe (Expr, Expr, Expr, Expr)
parseLinearFunctionBody' f i j e = result <$> go e
  where
    zero :: ArithmeticExpr
    zero = integerArithmeticExpr 0

    one :: ArithmeticExpr
    one = integerArithmeticExpr 1

    -- A term which goes entirely to @b@. The product part is the zero product @1 * 0@.
    onlyB :: Expr -> (ArithmeticExpr, ArithmeticExpr, ArithmeticExpr, ArithmeticExpr)
    onlyB x = (one, zero, parseArithmeticExpr x, zero)

    -- A term which goes entirely to @d@. The product part is the zero product @1 * 0@.
    onlyD :: Expr -> (ArithmeticExpr, ArithmeticExpr, ArithmeticExpr, ArithmeticExpr)
    onlyD x = (one, zero, zero, parseArithmeticExpr x)

    -- Move the integer factor of @a@ into @c@ and convert everything back to 'Expr'.
    result :: (ArithmeticExpr, ArithmeticExpr, ArithmeticExpr, ArithmeticExpr) -> (Expr, Expr, Expr, Expr)
    result (a, c, b, d) =
      let (k, a') = splitConstantFactorArithmeticExpr a
          c' = multArithmeticExpr (integerArithmeticExpr k) c
       in (formatArithmeticExpr a', formatArithmeticExpr c', formatArithmeticExpr b, formatArithmeticExpr d)

    -- Tuples are ordered as @(a, c, b, d)@ and stand for @a * c + b + d@.
    go :: Expr -> Maybe (ArithmeticExpr, ArithmeticExpr, ArithmeticExpr, ArithmeticExpr)
    go = \case
      Negate' e1 -> do
        (a, c, b, d) <- go e1
        -- Only one factor of the product @a * c@ is negated.
        return (a, negateArithmeticExpr c, negateArithmeticExpr b, negateArithmeticExpr d)
      Plus' e1 e2 -> do
        (a1, c1, b1, d1) <- go e1
        (a2, c2, b2, d2) <- go e2
        (a, c) <- plusPair (a1, c1) (a2, c2)
        return (a, c, plusArithmeticExpr b1 b2, plusArithmeticExpr d1 d2)
      Minus' e1 e2 -> do
        (a1, c1, b1, d1) <- go e1
        (a2, c2, b2, d2) <- go e2
        (a, c) <- plusPair (a1, c1) (negateArithmeticExpr a2, c2)
        return (a, c, minusArithmeticExpr b1 b2, minusArithmeticExpr d1 d2)
      Mult' e1 e2 -> do
        (a1, c1, b1, d1) <- go e1
        (a2, c2, b2, d2) <- go e2
        -- Expand @(a1 c1 + b1 + d1) (a2 c2 + b2 + d2)@.
        -- Every term except @b1 b2@ and @d1 d2@ must be merged into a single product @a * c@.
        (a, c) <-
          sumPairs
            [ (multArithmeticExpr a1 a2, multArithmeticExpr c1 c2),
              (multArithmeticExpr b2 a1, c1),
              (multArithmeticExpr b1 a2, c2),
              (a1, multArithmeticExpr c1 d2),
              (a2, multArithmeticExpr c2 d1),
              (b2, d1),
              (b1, d2)
            ]
        return (a, c, multArithmeticExpr b1 b2, multArithmeticExpr d1 d2)
      e'
        | f `isUnusedVar` e' && j `isUnusedVar` e' ->
          -- NOTE: Put constants to @d@ and simplify @a, b@
          return (onlyD e')
      e'
        | f `isUnusedVar` e' && i `isUnusedVar` e' ->
          return (onlyB e')
      e'@(At' _ (Var f') index) | f' == f -> case unNPlusKPattern (parseArithmeticExpr index) of
        -- @f[i + k]@ with a negative @k@ goes to @d@.
        Just (i', k) | i' == i && k < 0 -> return (onlyD e')
        -- @f[j]@ goes to @b@.
        Just (j', 0) | j' == j -> return (onlyB e')
        _ -> Nothing
      _ -> Nothing

-- | 'parseLinearFunctionBody' parses an expression built around @min@ or @max@ of a mapped range.
-- @parseLinearFunctionBody f i k e@ finds a tuple @(sign, a, c, b, d, e')@ for @e@.
--
-- The accepted core forms are @min@ and @max@ over @map (fun j -> step) (range size)@,
-- optionally with one extra element added by @cons@ or @snoc@.
-- Here @size@ must be @i + k@ according to 'unNPlusKPattern', and @step@ must be accepted by 'parseLinearFunctionBody''.
-- Around such a core, negation and addition, subtraction and multiplication with an operand satisfying 'isConstantTimeExpr' are accepted.
--
-- In the result, @a@ and @b@ have the bound variable @j@ replaced with @i@ using 'substitute'.
-- @e'@ is the extra element of @cons@ or @snoc@ if any, with @d@ subtracted from it (and negated for @max@).
-- For @max@, @sign@, @c@ and @b@ are negated so that the core can be read as a @min@.
--
-- 'Nothing' is returned when the expression has none of these forms.
parseLinearFunctionBody :: MonadAlpha m => VarName -> VarName -> Integer -> Expr -> m (Maybe (Expr, Expr, Expr, Expr, Expr, Maybe Expr))
parseLinearFunctionBody f i k = runMaybeT . go
  where
    -- Tuples are ordered as @(sign, a, c, b, d, e)@ everywhere below.
    goMin e j step size = case unNPlusKPattern (parseArithmeticExpr size) of
      Just (i', k') | i' == i && k' == k -> do
        (a, c, b, d) <- hoistMaybe $ parseLinearFunctionBody' f i j step
        -- rename @j@ to @i@
        a' <- lift $ substitute j (Var i) a
        b' <- lift $ substitute j (Var i) b
        return (LitInt' 1, a', c, b', d, (`Minus'` d) <$> e)
      _ -> hoistMaybe Nothing
    goMax e j step size = do
      (sign, a, c, b, d, e') <- goMin e j step size
      return (Negate' sign, a, Negate' c, Negate' b, d, Negate' <$> e')
    go = \case
      Min1' _ (Map' _ _ (Lam j _ step) (Range1' size)) -> goMin Nothing j step size
      Max1' _ (Map' _ _ (Lam j _ step) (Range1' size)) -> goMax Nothing j step size
      Min1' _ (Cons' _ e (Map' _ _ (Lam j _ step) (Range1' size))) -> goMin (Just e) j step size
      Max1' _ (Cons' _ e (Map' _ _ (Lam j _ step) (Range1' size))) -> goMax (Just e) j step size
      Min1' _ (Snoc' _ (Map' _ _ (Lam j _ step) (Range1' size)) e) -> goMin (Just e) j step size
      Max1' _ (Snoc' _ (Map' _ _ (Lam j _ step) (Range1' size)) e) -> goMax (Just e) j step size
      Negate' body -> do
        (sign, a, c, b, d, e) <- go body
        return (Negate' sign, a, c, b, Negate' d, e)
      Plus' e1 e2 | isConstantTimeExpr e2 -> do
        (sign, a, c, b, d, e) <- go e1
        return (sign, a, c, b, Plus' d e2, e)
      Plus' e1 e2 | isConstantTimeExpr e1 -> do
        (sign, a, c, b, d, e) <- go e2
        return (sign, a, c, b, Plus' e1 d, e)
      Minus' e1 e2 | isConstantTimeExpr e2 -> do
        (sign, a, c, b, d, e) <- go e1
        return (sign, a, c, b, Minus' d e2, e)
      Minus' e1 e2 | isConstantTimeExpr e1 -> do
        (sign, a, c, b, d, e) <- go e2
        -- @e1 - (sign * m + d) = (- sign) * m + (e1 - d)@
        return (Negate' sign, a, c, b, Minus' e1 d, e)
      Mult' e1 e2 | isConstantTimeExpr e2 -> do
        (sign, a, c, b, d, e) <- go e1
        return (Mult' sign e2, a, c, b, Mult' d e2, e)
      Mult' e1 e2 | isConstantTimeExpr e1 -> do
        (sign, a, c, b, d, e) <- go e2
        return (Mult' e1 sign, a, c, b, Mult' e1 d, e)
      _ -> hoistMaybe Nothing

-- | 'getLength' returns the length of a list expression which is written with only @nil@, @cons@ and @snoc@.
-- It returns 'Nothing' when any other expression appears at the list position.
getLength :: Expr -> Maybe Integer
getLength = \case
  Nil' _ -> Just 0
  Cons' _ _ xs -> succ <$> getLength xs
  Snoc' _ xs _ -> succ <$> getLength xs
  _ -> Nothing

rule :: (MonadAlpha m, MonadError Error m) => RewriteRule m
rule :: RewriteRule m
rule = String
-> (RewriteEnvironment -> Expr -> m (Maybe Expr)) -> RewriteRule m
forall (m :: * -> *).
Monad m =>
String
-> (RewriteEnvironment -> Expr -> m (Maybe Expr)) -> RewriteRule m
makeRewriteRule String
"Jikka.Core.Convert.ConvexHullTrick" ((RewriteEnvironment -> Expr -> m (Maybe Expr)) -> RewriteRule m)
-> (RewriteEnvironment -> Expr -> m (Maybe Expr)) -> RewriteRule m
forall a b. (a -> b) -> a -> b
$ \RewriteEnvironment
env -> \case
  -- build (fun f -> step(f)) base n
  Build' Type
IntTy (Lam VarName
f Type
_ Expr
step) Expr
base Expr
n -> MaybeT m Expr -> m (Maybe Expr)
forall (m :: * -> *) a. MaybeT m a -> m (Maybe a)
runMaybeT (MaybeT m Expr -> m (Maybe Expr))
-> MaybeT m Expr -> m (Maybe Expr)
forall a b. (a -> b) -> a -> b
$ do
    let ts :: [Type]
ts = [Type
ConvexHullTrickTy, Type -> Type
ListTy Type
IntTy]
    VarName
i <- m VarName -> MaybeT m VarName
forall (t :: (* -> *) -> * -> *) (m :: * -> *) a.
(MonadTrans t, Monad m) =>
m a -> t m a
lift m VarName
forall (m :: * -> *). MonadAlpha m => m VarName
genVarName'
    Integer
k <- Maybe Integer -> MaybeT m Integer
forall (m :: * -> *) a. Applicative m => Maybe a -> MaybeT m a
hoistMaybe (Maybe Integer -> MaybeT m Integer)
-> Maybe Integer -> MaybeT m Integer
forall a b. (a -> b) -> a -> b
$ Expr -> Maybe Integer
getLength Expr
base
    Expr
step <- VarName -> VarName -> Integer -> Expr -> MaybeT m Expr
forall (m :: * -> *).
MonadError Error m =>
VarName -> VarName -> Integer -> Expr -> m Expr
replaceLenF VarName
f VarName
i Integer
k Expr
step
    -- step(f) = sign() * min (cons e(f, i) (map (fun j -> a(f, j) c(f, i) + b(f, j)) (range (i + k)))) + d(f, i)
    (Expr
sign, Expr
a, Expr
c, Expr
b, Expr
d, Maybe Expr
e) <- m (Maybe (Expr, Expr, Expr, Expr, Expr, Maybe Expr))
-> MaybeT m (Expr, Expr, Expr, Expr, Expr, Maybe Expr)
forall (m :: * -> *) a. m (Maybe a) -> MaybeT m a
MaybeT (m (Maybe (Expr, Expr, Expr, Expr, Expr, Maybe Expr))
 -> MaybeT m (Expr, Expr, Expr, Expr, Expr, Maybe Expr))
-> m (Maybe (Expr, Expr, Expr, Expr, Expr, Maybe Expr))
-> MaybeT m (Expr, Expr, Expr, Expr, Expr, Maybe Expr)
forall a b. (a -> b) -> a -> b
$ VarName
-> VarName
-> Integer
-> Expr
-> m (Maybe (Expr, Expr, Expr, Expr, Expr, Maybe Expr))
forall (m :: * -> *).
MonadAlpha m =>
VarName
-> VarName
-> Integer
-> Expr
-> m (Maybe (Expr, Expr, Expr, Expr, Expr, Maybe Expr))
parseLinearFunctionBody VarName
f VarName
i Integer
k Expr
step
    -- Update base when k = 0. If user's program has no bugs, it uses min(cons(x, xs)) when k = 0.
    (Expr
base, Expr
n, Integer
k, Expr
c, Expr
d, Maybe Expr
e) <- case (Maybe Expr
e, Integer
k) of
      (Just Expr
e, Integer
0) -> do
        Expr
e0 <- m Expr -> MaybeT m Expr
forall (t :: (* -> *) -> * -> *) (m :: * -> *) a.
(MonadTrans t, Monad m) =>
m a -> t m a
lift (m Expr -> MaybeT m Expr) -> m Expr -> MaybeT m Expr
forall a b. (a -> b) -> a -> b
$ VarName -> Expr -> Expr -> m Expr
forall (m :: * -> *).
MonadAlpha m =>
VarName -> Expr -> Expr -> m Expr
substitute VarName
i (Integer -> Expr
LitInt' Integer
0) Expr
e
        Expr
d0 <- m Expr -> MaybeT m Expr
forall (t :: (* -> *) -> * -> *) (m :: * -> *) a.
(MonadTrans t, Monad m) =>
m a -> t m a
lift (m Expr -> MaybeT m Expr) -> m Expr -> MaybeT m Expr
forall a b. (a -> b) -> a -> b
$ VarName -> Expr -> Expr -> m Expr
forall (m :: * -> *).
MonadAlpha m =>
VarName -> Expr -> Expr -> m Expr
substitute VarName
i (Integer -> Expr
LitInt' Integer
0) Expr
d
        let e0' :: Expr
e0' = VarName -> Type -> Expr -> Expr -> Expr
Let VarName
f (Type -> Type
ListTy Type
IntTy) Expr
base Expr
e0
        let base' :: Expr
base' = Type -> Expr -> Expr -> Expr
Snoc' Type
IntTy Expr
base (Expr -> Expr -> Expr
Plus' (Expr -> Expr -> Expr
Mult' Expr
sign Expr
e0') Expr
d0)
        Expr
c <- m Expr -> MaybeT m Expr
forall (t :: (* -> *) -> * -> *) (m :: * -> *) a.
(MonadTrans t, Monad m) =>
m a -> t m a
lift (m Expr -> MaybeT m Expr) -> m Expr -> MaybeT m Expr
forall a b. (a -> b) -> a -> b
$ VarName -> Expr -> Expr -> m Expr
forall (m :: * -> *).
MonadAlpha m =>
VarName -> Expr -> Expr -> m Expr
substitute VarName
i (Expr -> Expr -> Expr
Plus' (VarName -> Expr
Var VarName
i) (Integer -> Expr
LitInt' Integer
1)) Expr
c
        Expr
d <- m Expr -> MaybeT m Expr
forall (t :: (* -> *) -> * -> *) (m :: * -> *) a.
(MonadTrans t, Monad m) =>
m a -> t m a
lift (m Expr -> MaybeT m Expr) -> m Expr -> MaybeT m Expr
forall a b. (a -> b) -> a -> b
$ VarName -> Expr -> Expr -> m Expr
forall (m :: * -> *).
MonadAlpha m =>
VarName -> Expr -> Expr -> m Expr
substitute VarName
i (Expr -> Expr -> Expr
Plus' (VarName -> Expr
Var VarName
i) (Integer -> Expr
LitInt' Integer
1)) Expr
d
        Expr
e <- m Expr -> MaybeT m Expr
forall (t :: (* -> *) -> * -> *) (m :: * -> *) a.
(MonadTrans t, Monad m) =>
m a -> t m a
lift (m Expr -> MaybeT m Expr) -> m Expr -> MaybeT m Expr
forall a b. (a -> b) -> a -> b
$ VarName -> Expr -> Expr -> m Expr
forall (m :: * -> *).
MonadAlpha m =>
VarName -> Expr -> Expr -> m Expr
substitute VarName
i (Expr -> Expr -> Expr
Plus' (VarName -> Expr
Var VarName
i) (Integer -> Expr
LitInt' Integer
1)) Expr
e
        (Expr, Expr, Integer, Expr, Expr, Maybe Expr)
-> MaybeT m (Expr, Expr, Integer, Expr, Expr, Maybe Expr)
forall (m :: * -> *) a. Monad m => a -> m a
return (Expr
base', Expr -> Expr -> Expr
Minus' Expr
n (Integer -> Expr
LitInt' Integer
1), Integer
k Integer -> Integer -> Integer
forall a. Num a => a -> a -> a
+ Integer
1, Expr
c, Expr
d, Expr -> Maybe Expr
forall a. a -> Maybe a
Just Expr
e)
      (Maybe Expr, Integer)
_ -> (Expr, Expr, Integer, Expr, Expr, Maybe Expr)
-> MaybeT m (Expr, Expr, Integer, Expr, Expr, Maybe Expr)
forall (m :: * -> *) a. Monad m => a -> m a
return (Expr
base, Expr
n, Integer
k, Expr
c, Expr
d, Maybe Expr
e)
    -- base' = (cht, base)
    Expr
base' <- do
      VarName
x <- m VarName -> MaybeT m VarName
forall (t :: (* -> *) -> * -> *) (m :: * -> *) a.
(MonadTrans t, Monad m) =>
m a -> t m a
lift m VarName
forall (m :: * -> *). MonadAlpha m => m VarName
genVarName'
      VarName
f' <- m VarName -> MaybeT m VarName
forall (t :: (* -> *) -> * -> *) (m :: * -> *) a.
(MonadTrans t, Monad m) =>
m a -> t m a
lift (m VarName -> MaybeT m VarName) -> m VarName -> MaybeT m VarName
forall a b. (a -> b) -> a -> b
$ VarName -> m VarName
forall (m :: * -> *). MonadAlpha m => VarName -> m VarName
genVarName VarName
f
      VarName
i' <- m VarName -> MaybeT m VarName
forall (t :: (* -> *) -> * -> *) (m :: * -> *) a.
(MonadTrans t, Monad m) =>
m a -> t m a
lift (m VarName -> MaybeT m VarName) -> m VarName -> MaybeT m VarName
forall a b. (a -> b) -> a -> b
$ VarName -> m VarName
forall (m :: * -> *). MonadAlpha m => VarName -> m VarName
genVarName VarName
i
      Expr
a <- m Expr -> MaybeT m Expr
forall (t :: (* -> *) -> * -> *) (m :: * -> *) a.
(MonadTrans t, Monad m) =>
m a -> t m a
lift (m Expr -> MaybeT m Expr) -> m Expr -> MaybeT m Expr
forall a b. (a -> b) -> a -> b
$ VarName -> Expr -> Expr -> m Expr
forall (m :: * -> *).
MonadAlpha m =>
VarName -> Expr -> Expr -> m Expr
substitute VarName
f (VarName -> Expr
Var VarName
f') Expr
a
      Expr
b <- m Expr -> MaybeT m Expr
forall (t :: (* -> *) -> * -> *) (m :: * -> *) a.
(MonadTrans t, Monad m) =>
m a -> t m a
lift (m Expr -> MaybeT m Expr) -> m Expr -> MaybeT m Expr
forall a b. (a -> b) -> a -> b
$ VarName -> Expr -> Expr -> m Expr
forall (m :: * -> *).
MonadAlpha m =>
VarName -> Expr -> Expr -> m Expr
substitute VarName
f (VarName -> Expr
Var VarName
f') Expr
b
      Expr
a <- m Expr -> MaybeT m Expr
forall (t :: (* -> *) -> * -> *) (m :: * -> *) a.
(MonadTrans t, Monad m) =>
m a -> t m a
lift (m Expr -> MaybeT m Expr) -> m Expr -> MaybeT m Expr
forall a b. (a -> b) -> a -> b
$ VarName -> Expr -> Expr -> m Expr
forall (m :: * -> *).
MonadAlpha m =>
VarName -> Expr -> Expr -> m Expr
substitute VarName
i (VarName -> Expr
Var VarName
i') Expr
a
      Expr
b <- m Expr -> MaybeT m Expr
forall (t :: (* -> *) -> * -> *) (m :: * -> *) a.
(MonadTrans t, Monad m) =>
m a -> t m a
lift (m Expr -> MaybeT m Expr) -> m Expr -> MaybeT m Expr
forall a b. (a -> b) -> a -> b
$ VarName -> Expr -> Expr -> m Expr
forall (m :: * -> *).
MonadAlpha m =>
VarName -> Expr -> Expr -> m Expr
substitute VarName
i (VarName -> Expr
Var VarName
i') Expr
b
      -- cht for base[0], ..., base[k - 1]
      let cht :: Expr
cht = Type -> Type -> Expr -> Expr -> Expr -> Expr
Foldl' Type
IntTy Type
ConvexHullTrickTy (VarName -> Type -> VarName -> Type -> Expr -> Expr
Lam2 VarName
x Type
ConvexHullTrickTy VarName
i' Type
IntTy (Expr -> Expr -> Expr -> Expr
ConvexHullTrickInsert' (VarName -> Expr
Var VarName
x) Expr
a Expr
b)) Expr
ConvexHullTrickInit' (Expr -> Expr
Range1' (Integer -> Expr
LitInt' Integer
k))
      Expr -> MaybeT m Expr
forall (m :: * -> *) a. Monad m => a -> m a
return (Expr -> MaybeT m Expr) -> Expr -> MaybeT m Expr
forall a b. (a -> b) -> a -> b
$
        VarName -> Type -> Expr -> Expr -> Expr
Let VarName
f' (Type -> Type
ListTy Type
IntTy) Expr
base (Expr -> Expr) -> Expr -> Expr
forall a b. (a -> b) -> a -> b
$
          Expr -> [Expr] -> Expr
uncurryApp ([Type] -> Expr
Tuple' [Type]
ts) [Expr
cht, VarName -> Expr
Var VarName
f']
    -- step' = fun (cht, f) i ->
    --     let f' = setat f index(i) value(..)
    --     in let cht' = update cht a(i) b(i)
    --     in (cht', f')
    Expr
step' <- do
      VarName
x <- m VarName -> MaybeT m VarName
forall (t :: (* -> *) -> * -> *) (m :: * -> *) a.
(MonadTrans t, Monad m) =>
m a -> t m a
lift m VarName
forall (m :: * -> *). MonadAlpha m => m VarName
genVarName'
      -- value(..) = (min e (min cht f[i + k] + c(i)))
      let value :: Expr
value = Expr -> Expr -> Expr
Plus' (Expr -> Expr -> Expr
Mult' Expr
sign ((Expr -> Expr)
-> (Expr -> Expr -> Expr) -> Maybe Expr -> Expr -> Expr
forall b a. b -> (a -> b) -> Maybe a -> b
maybe Expr -> Expr
forall a. a -> a
id (\Expr
e -> Type -> Expr -> Expr -> Expr
Min2' Type
IntTy Expr
e) Maybe Expr
e (Expr -> Expr -> Expr
ConvexHullTrickGetMin' ([Type] -> Integer -> Expr -> Expr
Proj' [Type]
ts Integer
0 (VarName -> Expr
Var VarName
x)) Expr
c))) Expr
d
      VarName
y <- m VarName -> MaybeT m VarName
forall (t :: (* -> *) -> * -> *) (m :: * -> *) a.
(MonadTrans t, Monad m) =>
m a -> t m a
lift m VarName
forall (m :: * -> *). MonadAlpha m => m VarName
genVarName'
      VarName
f' <- m VarName -> MaybeT m VarName
forall (t :: (* -> *) -> * -> *) (m :: * -> *) a.
(MonadTrans t, Monad m) =>
m a -> t m a
lift (m VarName -> MaybeT m VarName) -> m VarName -> MaybeT m VarName
forall a b. (a -> b) -> a -> b
$ VarName -> m VarName
forall (m :: * -> *). MonadAlpha m => VarName -> m VarName
genVarName VarName
f
      Expr
a <- m Expr -> MaybeT m Expr
forall (t :: (* -> *) -> * -> *) (m :: * -> *) a.
(MonadTrans t, Monad m) =>
m a -> t m a
lift (m Expr -> MaybeT m Expr) -> m Expr -> MaybeT m Expr
forall a b. (a -> b) -> a -> b
$ VarName -> Expr -> Expr -> m Expr
forall (m :: * -> *).
MonadAlpha m =>
VarName -> Expr -> Expr -> m Expr
substitute VarName
f (VarName -> Expr
Var VarName
f') Expr
a
      Expr
b <- m Expr -> MaybeT m Expr
forall (t :: (* -> *) -> * -> *) (m :: * -> *) a.
(MonadTrans t, Monad m) =>
m a -> t m a
lift (m Expr -> MaybeT m Expr) -> m Expr -> MaybeT m Expr
forall a b. (a -> b) -> a -> b
$ VarName -> Expr -> Expr -> m Expr
forall (m :: * -> *).
MonadAlpha m =>
VarName -> Expr -> Expr -> m Expr
substitute VarName
f (VarName -> Expr
Var VarName
f') Expr
b
      Expr
a <- m Expr -> MaybeT m Expr
forall (t :: (* -> *) -> * -> *) (m :: * -> *) a.
(MonadTrans t, Monad m) =>
m a -> t m a
lift (m Expr -> MaybeT m Expr) -> m Expr -> MaybeT m Expr
forall a b. (a -> b) -> a -> b
$ VarName -> Expr -> Expr -> m Expr
forall (m :: * -> *).
MonadAlpha m =>
VarName -> Expr -> Expr -> m Expr
substitute VarName
i (Expr -> Expr -> Expr
Plus' (VarName -> Expr
Var VarName
i) (Integer -> Expr
LitInt' Integer
k)) Expr
a
      Expr
b <- m Expr -> MaybeT m Expr
forall (t :: (* -> *) -> * -> *) (m :: * -> *) a.
(MonadTrans t, Monad m) =>
m a -> t m a
lift (m Expr -> MaybeT m Expr) -> m Expr -> MaybeT m Expr
forall a b. (a -> b) -> a -> b
$ VarName -> Expr -> Expr -> m Expr
forall (m :: * -> *).
MonadAlpha m =>
VarName -> Expr -> Expr -> m Expr
substitute VarName
i (Expr -> Expr -> Expr
Plus' (VarName -> Expr
Var VarName
i) (Integer -> Expr
LitInt' Integer
k)) Expr
b
      Expr -> MaybeT m Expr
forall (m :: * -> *) a. Monad m => a -> m a
return (Expr -> MaybeT m Expr) -> Expr -> MaybeT m Expr
forall a b. (a -> b) -> a -> b
$
        VarName -> Type -> VarName -> Type -> Expr -> Expr
Lam2 VarName
x ([Type] -> Type
TupleTy [Type]
ts) VarName
i Type
IntTy (Expr -> Expr) -> Expr -> Expr
forall a b. (a -> b) -> a -> b
$
          VarName -> Type -> Expr -> Expr -> Expr
Let VarName
f (Type -> Type
ListTy Type
IntTy) ([Type] -> Integer -> Expr -> Expr
Proj' [Type]
ts Integer
1 (VarName -> Expr
Var VarName
x)) (Expr -> Expr) -> Expr -> Expr
forall a b. (a -> b) -> a -> b
$
            VarName -> Type -> Expr -> Expr -> Expr
Let VarName
f' (Type -> Type
ListTy Type
IntTy) (Type -> Expr -> Expr -> Expr
Snoc' Type
IntTy (VarName -> Expr
Var VarName
f) Expr
value) (Expr -> Expr) -> Expr -> Expr
forall a b. (a -> b) -> a -> b
$
              VarName -> Type -> Expr -> Expr -> Expr
Let VarName
y Type
ConvexHullTrickTy (Expr -> Expr -> Expr -> Expr
ConvexHullTrickInsert' ([Type] -> Integer -> Expr -> Expr
Proj' [Type]
ts Integer
0 (VarName -> Expr
Var VarName
x)) Expr
a Expr
b) (Expr -> Expr) -> Expr -> Expr
forall a b. (a -> b) -> a -> b
$
                Expr -> [Expr] -> Expr
uncurryApp ([Type] -> Expr
Tuple' [Type]
ts) [VarName -> Expr
Var VarName
y, VarName -> Expr
Var VarName
f']
    -- proj 1 (foldl step' base' (range (n - 1)))
    let e :: Expr
e = [Type] -> Integer -> Expr -> Expr
Proj' [Type]
ts Integer
1 (Type -> Type -> Expr -> Expr -> Expr -> Expr
Foldl' Type
IntTy ([Type] -> Type
TupleTy [Type]
ts) Expr
step' Expr
base' (Expr -> Expr
Range1' Expr
n))
    m Expr -> MaybeT m Expr
forall (t :: (* -> *) -> * -> *) (m :: * -> *) a.
(MonadTrans t, Monad m) =>
m a -> t m a
lift (m Expr -> MaybeT m Expr) -> m Expr -> MaybeT m Expr
forall a b. (a -> b) -> a -> b
$ [(VarName, Type)] -> Expr -> m Expr
forall (m :: * -> *).
(MonadAlpha m, MonadError Error m) =>
[(VarName, Type)] -> Expr -> m Expr
Alpha.runExpr (RewriteEnvironment -> [(VarName, Type)]
typeEnv RewriteEnvironment
env) Expr
e
  Expr
_ -> Maybe Expr -> m (Maybe Expr)
forall (m :: * -> *) a. Monad m => a -> m a
return Maybe Expr
forall a. Maybe a
Nothing

-- | Hands the convex hull trick 'rule' to @applyRewriteRuleProgram'@ to rewrite the given program.
--
-- Any failure raised while rewriting is reported through @MonadError Error@.
runProgram :: (MonadAlpha m, MonadError Error m) => Program -> m Program
runProgram = applyRewriteRuleProgram' rule

-- | `run` optimizes a DP which has the recurrence relation
-- \[
--     \mathrm{dp}(i) = \min \lbrace a(j) x(i) + b(j) \mid j \lt i \rbrace + c(i)
-- \] where only appropriate elements of \(\mathrm{dp}\) are used in \(a, x, b, c\).
--
-- The input program is linted inside @precondition@, rewritten with 'runProgram',
-- and the rewritten program is linted again inside @postcondition@ before it is returned.
-- The whole pass runs under @wrapError'@ tagged with this module's name.
run :: (MonadAlpha m, MonadError Error m) => Program -> m Program
run prog = wrapError' "Jikka.Core.Convert.ConvexHullTrick" $ do
  -- Check the input before rewriting it.
  precondition $ lint prog
  -- A fresh name keeps the rewritten program distinct from the argument.
  prog' <- runProgram prog
  -- Check that the rewritten program still passes the linter.
  postcondition $ lint prog'
  return prog'