{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE QuasiQuotes #-}
{-# LANGUAGE ViewPatterns #-}

-- |
-- Module      : Jikka.Core.Convert.ShortCutFusion
-- Description : does short cut fusion. / short cut fusion を行います。
-- Copyright   : (c) Kimiyuki Onaka, 2021
-- License     : Apache License 2.0
-- Maintainer  : kimiyuki95@gmail.com
-- Stability   : experimental
-- Portability : portable
--
-- \[
--     \newcommand\int{\mathbf{int}}
--     \newcommand\bool{\mathbf{bool}}
--     \newcommand\list{\mathbf{list}}
-- \]
module Jikka.Core.Convert.ShortCutFusion
  ( run,

    -- * internal rules
    rule,
    reduceBuild,
    reduceMapBuild,
    reduceMap,
    reduceMapMap,
    reduceFoldMap,
    reduceFold,
    reduceFoldBuild,
  )
where

import Jikka.Common.Alpha
import Jikka.Common.Error
import Jikka.Core.Format (formatExpr)
import Jikka.Core.Language.BuiltinPatterns
import Jikka.Core.Language.Expr
import Jikka.Core.Language.FreeVars
import Jikka.Core.Language.Lint
import Jikka.Core.Language.QuasiRules
import Jikka.Core.Language.RewriteRules

-- | `reduceBuild` rewrites the multi-argument range builders into `map` over the one-argument `range`.
--
-- * `Range1` remains.
-- * `Range2` is removed.
-- * `Range3` is removed.
-- * `Nil` and `Cons` are kept as is.
reduceBuild :: (MonadAlpha m, MonadError Error m) => RewriteRule m
reduceBuild =
  mconcat
    [ -- range2 l r enumerates l, l + 1, ... by shifting each index of a range of length r - l.
      [r| "range2" forall l r. range2 l r = map (fun i -> l + i) (range (r - l)) |],
      -- range3 additionally scales each index by step.
      -- NOTE(review): the length uses the `/^` operator, presumably ceiling division; confirm against the core language definition.
      [r| "range3" forall l r step. range3 l r step = map (fun i -> l + i * step) (range ((r - l) /^ step)) |]
    ]

-- | `reduceMapBuild` reduces list-to-list functions (`sorted`, `reversed`, `filter` and `map`)
-- which are applied directly to list builders (`nil`, `range` and `cons`).
reduceMapBuild :: (MonadAlpha m, MonadError Error m) => RewriteRule m
reduceMapBuild =
  mconcat
    [ -- reduce `Sorted`
      [r| "sorted/nil" sorted nil = nil |],
      [r| "sorted/range" forall n. sorted (range n) = range n |],
      -- reduce `Reversed`
      [r| "reversed/nil" reversed nil = nil |],
      [r| "reversed/range" forall n. reversed (range n) = map (fun i -> n - i - 1) (range n) |],
      -- reduce `Filter`
      [r| "filter/nil" filter _ nil = nil |],
      -- reduce `Map`
      [r| "map/nil" map _ nil = nil |],
      [r| "map/cons" forall f x xs. map f (cons x xs) = cons (f x) (map f xs) |]
    ]

-- | `reduceMap` removes `map` and `filter` whose function argument is trivial:
-- the identity function for `map`, and a constant predicate for `filter`.
reduceMap :: (MonadAlpha m, MonadError Error m) => RewriteRule m
reduceMap =
  mconcat
    [ [r| "map/id" forall xs. map (fun x -> x) xs = xs |],
      -- A predicate which is always false drops every element.
      [r| "filter/const-false" forall xs. filter (fun _ -> false) xs = nil |],
      -- A predicate which is always true keeps every element.
      [r| "filter/const-true" forall xs. filter (fun _ -> true) xs = xs |]
    ]

-- | `reduceMapMap` fuses and reorders nested list-to-list functions.
--
-- * Functions are reordered as:
--
--     * `Sorted` and `Reversed` (functions to reorder) are lastly applied to lists
--     * `Map` (functions to modify lists)
--     * `Filter` (functions to reduce lengths) is firstly applied to lists
reduceMapMap :: (MonadAlpha m, MonadError Error m) => RewriteRule m
reduceMapMap =
  mconcat
    [ -- fuse two maps into one, and move `Map` inside `Reversed`
      [r| "map/map" forall f g xs. map g (map f xs) = map (fun x -> g (f x)) xs |],
      [r| "map/reversed" forall f xs. map f (reversed xs) = reversed (map f xs) |],
      -- fuse two filters into one, and move `Filter` inside `Sorted` and `Reversed`
      [r| "filter/filter" forall f g xs. filter g (filter f xs) = filter (fun x -> f x && g x) xs |],
      [r| "filter/sorted" forall f xs. filter f (sorted xs) = sorted (filter f xs) |],
      [r| "filter/reversed" forall f xs. filter f (reversed xs) = reversed (filter f xs) |],
      -- cancel redundant reorderings
      [r| "reversed/reversed" forall xs. reversed (reversed xs) = xs |],
      [r| "sorted/reversed" forall xs. sorted (reversed xs) = sorted xs |],
      [r| "sorted/sorted" forall xs. sorted (sorted xs) = sorted xs |]
    ]

-- | `reduceFoldMap` reduces list consumers (`len`, `elem`, subscription and `foldl`)
-- which are applied to results of `reversed`, `sorted` and `map`.
-- It also has a few rules for `len` and subscription on results of element assignment and `scanl`.
reduceFoldMap :: (MonadAlpha m, MonadError Error m) => RewriteRule m
reduceFoldMap =
  mconcat
    [ -- reduce `Reversed`
      [r| "len/reversed" forall xs. len (reversed xs) = len xs |],
      [r| "elem/reversed" forall x xs. elem x (reversed xs) = elem x xs |],
      [r| "at/reversed" forall xs i. (reversed xs)[i] = xs[len(xs) - i - 1] |],
      -- reduce `Sorted`
      [r| "len/sorted" forall xs. len (sorted xs) = len xs |],
      [r| "elem/sorted" forall x xs. elem x (sorted xs) = elem x xs |],
      -- reduce `Map`
      [r| "len/map" forall f xs. len (map f xs) = len xs |],
      [r| "at/map" forall f xs i. (map f xs)[i] = f xs[i] |],
      [r| "foldl/map" forall g init f xs. foldl g init (map f xs) = foldl (fun y x -> g y (f x)) init xs|],
      -- others
      [r| "len/setat" forall xs i x. len xs[i <- x] = len xs |],
      [r| "len/scanl" forall f init xs. len (scanl f init xs) = len xs + 1 |],
      [r| "at/setat" forall xs i x j. xs[i <- x][j] = if i == j then x else xs[j] |]
    ]

-- | `reduceFold` replaces a `Foldl` whose step function ignores the list element with an `Iterate`.
--
-- The rule fires only when the second variable of the two-argument lambda does not occur in its body.
-- The result is an `Iterate'` built from the length of the list, the lambda over the accumulator alone,
-- and the initial value.
reduceFold :: Monad m => RewriteRule m
reduceFold = simpleRewriteRule "foldl->iterate" $ \case
  -- x2 is the accumulator and x1 is the list element; the guard requires that the body never uses x1.
  Foldl' t1 t2 (Lam2 x2 _ x1 _ body) init xs
    | x1 `isUnusedVar` body ->
      Just $ Iterate' t2 (Len' t1 xs) (Lam x2 t2 body) init
  _ -> Nothing

-- | `reduceFoldBuild` reduces list consumers (`foldl`, `len`, subscription and `elem`)
-- which are applied directly to list builders (`nil`, `cons` and `range`), and `len` of `build`.
reduceFoldBuild :: (MonadAlpha m, MonadError Error m) => RewriteRule m
reduceFoldBuild =
  mconcat
    [ -- reduce `Foldl`
      [r| "foldl/nil" forall f init. foldl f init nil = init |],
      [r| "foldl/cons" forall f init x xs. foldl f init (cons x xs) = foldl f (f init x) xs |],
      -- reduce `Len`
      [r| "len/nil" len nil = 0 |],
      [r| "len/cons" forall x xs. len (cons x xs) = 1 + len xs |],
      [r| "len/range" forall n. len (range n) = n |],
      -- reduce `At`
      -- Subscripting an empty list becomes a `Bottom'` whose message includes the formatted index.
      simpleRewriteRule "at/nil" $ \case
        At' t (Nil' _) i -> Just $ Bottom' t $ "cannot subscript empty list: index = " ++ formatExpr i
        _ -> Nothing,
      [r| "at/cons" forall x xs i. (cons x xs)[i] = if i == 0 then x else xs[i - 1] |],
      -- NOTE(review): this rule does not check that i is within the bounds of the range.
      [r| "at/range" forall n i. (range n)[i] = i |],
      -- reduce `Elem`
      [r| "elem/nil" forall y. elem y nil = false |],
      [r| "elem/cons" forall y x xs. elem y (cons x xs) = y == x || elem y xs |],
      [r| "elem/range" forall i n. elem i (range n) = 0 <= i && i < n |],
      -- others
      [r| "len/build" forall f base n. len (build f base n) = len base + n |]
    ]

-- | `rule` is the combination of all rule sets of this module.
-- The rule sets are concatenated with `mconcat` in the order listed below.
rule :: (MonadAlpha m, MonadError Error m) => RewriteRule m
rule =
  mconcat
    [ reduceFoldMap,
      reduceMap,
      reduceMapMap,
      reduceFoldBuild,
      reduceMapBuild,
      reduceBuild,
      reduceFold
    ]

-- | `runProgram` applies `rule` to a program with `applyRewriteRuleProgram'`.
runProgram :: (MonadAlpha m, MonadError Error m) => Program -> m Program
runProgram = applyRewriteRuleProgram' rule

-- | `run` does short cut fusion.
--
-- * This function is mainly for polymorphic reductions. This doesn't do much about concrete things, e.g., arithmetical operations.
-- * This does nothing about `Build`, `Scanl` or `SetAt` except combinations with `Len` or `At`.
-- * The program is linted before and after the rewriting, as a precondition and a postcondition.
--
-- == Example
--
-- Before:
--
-- > length (map f (cons (-1) (range n)))
--
-- After:
--
-- > n + 1
--
-- == List of builtin functions which are reduced
--
-- === Build functions
--
-- * `Nil` \(: \forall \alpha. \list(\alpha)\)
-- * `Cons` \(: \forall \alpha. \alpha \to \list(\alpha) \to \list(\alpha)\)
-- * `Range1` \(: \int \to \list(\int)\)
-- * `Range2` \(: \int \to \int \to \list(\int)\)
-- * `Range3` \(: \int \to \int \to \int \to \list(\int)\)
--
-- === Map functions
--
-- * `Map` \(: \forall \alpha \beta. (\alpha \to \beta) \to \list(\alpha) \to \list(\beta)\)
-- * `Filter` \(: \forall \alpha. (\alpha \to \bool) \to \list(\alpha) \to \list(\alpha)\)
-- * `Reversed` \(: \forall \alpha. \list(\alpha) \to \list(\alpha)\)
-- * `Sorted` \(: \forall \alpha. \list(\alpha) \to \list(\alpha)\)
--
-- === Fold functions
--
-- * `Foldl` \(: \forall \alpha \beta. (\beta \to \alpha \to \beta) \to \beta \to \list(\alpha) \to \beta\)
-- * `Len` \(: \forall \alpha. \list(\alpha) \to \int\)
-- * `At` \(: \forall \alpha. \list(\alpha) \to \int \to \alpha\)
-- * `Elem` \(: \forall \alpha. \alpha \to \list(\alpha) \to \bool\)
run :: (MonadAlpha m, MonadError Error m) => Program -> m Program
run prog = wrapError' "Jikka.Core.Convert.ShortCutFusion" $ do
  -- Lint the input before rewriting.
  precondition $ lint prog
  prog' <- runProgram prog
  -- Lint the rewritten program before returning it.
  postcondition $ lint prog'
  return prog'