{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE QuasiQuotes #-}
{-# LANGUAGE ViewPatterns #-}
module Jikka.Core.Convert.ShortCutFusion
( run,
rule,
reduceBuild,
reduceMapBuild,
reduceMap,
reduceMapMap,
reduceFoldMap,
reduceFold,
reduceFoldBuild,
)
where
import Jikka.Common.Alpha
import Jikka.Common.Error
import Jikka.Core.Format (formatExpr)
import Jikka.Core.Language.BuiltinPatterns
import Jikka.Core.Language.Expr
import Jikka.Core.Language.FreeVars
import Jikka.Core.Language.Lint
import Jikka.Core.Language.QuasiRules
import Jikka.Core.Language.RewriteRules
-- | Eliminates the two- and three-argument range builtins by rewriting each
-- of them into a @map@ over the one-argument @range@:
--
-- * @range2 l r@ becomes a @map@ adding @l@ to every element of
--   @range (r - l)@.
-- * @range3 l r step@ becomes a @map@ computing @l + i * step@ over a
--   one-argument @range@ whose length is derived from @r - l@ and @step@.
--
-- The one-argument @range@ itself is not rewritten here.
reduceBuild :: (MonadAlpha m, MonadError Error m) => RewriteRule m
reduceBuild =
  mconcat
    [ [r| "range2" forall l r. range2 l r = map (fun i -> l + i) (range (r - l)) |],
      [r| "range3" forall l r step. range3 l r step = map (fun i -> l + i * step) (range ((r - l) /^ step)) |]
    ]

-- | Simplifies list transformers applied directly to a list whose shape is
-- known from its constructor:
--
-- * @sorted@ and @reversed@ of @nil@ are @nil@.
-- * @sorted (range n)@ is just @range n@.
-- * @reversed (range n)@ becomes a @map@ over @range n@ counting down from
--   @n - 1@ to @0@.
-- * @filter@ and @map@ of @nil@ are @nil@.
-- * @map@ over a @cons@ cell is pushed into the head and the tail.
reduceMapBuild :: (MonadAlpha m, MonadError Error m) => RewriteRule m
reduceMapBuild =
  mconcat
    [ -- sorted
      [r| "sorted/nil" sorted nil = nil |],
      [r| "sorted/range" forall n. sorted (range n) = range n |],
      -- reversed
      [r| "reversed/nil" reversed nil = nil |],
      [r| "reversed/range" forall n. reversed (range n) = map (fun i -> n - i - 1) (range n) |],
      -- filter
      [r| "filter/nil" filter _ nil = nil |],
      -- map
      [r| "map/nil" map _ nil = nil |],
      [r| "map/cons" forall f x xs. map f (cons x xs) = cons (f x) (map f xs) |]
    ]

-- | Removes @map@ and @filter@ calls whose function argument is trivial:
--
-- * @map@ with the identity lambda returns its input list unchanged.
-- * @filter@ with a predicate that is constantly @false@ yields @nil@.
-- * @filter@ with a predicate that is constantly @true@ returns its input.
reduceMap :: (MonadAlpha m, MonadError Error m) => RewriteRule m
reduceMap =
  mconcat
    [ [r| "map/id" forall xs. map (fun x -> x) xs = xs |],
      [r| "filter/const-false" forall xs. filter (fun _ -> false) xs = nil |],
      [r| "filter/const-true" forall xs. filter (fun _ -> true) xs = xs |]
    ]

-- | Fuses or cancels two directly nested list transformers:
--
-- * two nested @map@s become a single @map@ over the composed function;
-- * two nested @filter@s become a single @filter@ over the conjoined
--   predicate;
-- * @map@ and @filter@ are moved inside @reversed@, and @filter@ is moved
--   inside @sorted@;
-- * @reversed@ applied twice cancels out;
-- * @sorted@ absorbs an inner @reversed@ or @sorted@.
reduceMapMap :: (MonadAlpha m, MonadError Error m) => RewriteRule m
reduceMapMap =
  mconcat
    [ -- map / filter fusion
      [r| "map/map" forall f g xs. map g (map f xs) = map (fun x -> g (f x)) xs |],
      [r| "map/reversed" forall f xs. map f (reversed xs) = reversed (map f xs) |],
      [r| "filter/filter" forall f g xs. filter g (filter f xs) = filter (fun x -> f x && g x) xs |],
      [r| "filter/sorted" forall f xs. filter f (sorted xs) = sorted (filter f xs) |],
      [r| "filter/reversed" forall f xs. filter f (reversed xs) = reversed (filter f xs) |],
      -- reordering operations that cancel or absorb each other
      [r| "reversed/reversed" forall xs. reversed (reversed xs) = xs |],
      [r| "sorted/reversed" forall xs. sorted (reversed xs) = sorted xs |],
      [r| "sorted/sorted" forall xs. sorted (sorted xs) = sorted xs |]
    ]

-- | Simplifies a consumer (@len@, @elem@, subscripting, @foldl@) applied to
-- the result of a list transformer. The transformed list no longer appears
-- on the right-hand side of these rules:
--
-- * @len@ and @elem@ see through @reversed@ and @sorted@; subscripting a
--   reversed list reads the original list at the mirrored index.
-- * @len@ sees through @map@; subscripting a mapped list applies the
--   function to a single element; @foldl@ over a mapped list applies the
--   function inside the step function instead.
-- * @len@ sees through a point update (set-at) and through @scanl@, which
--   contributes one extra element; subscripting a point-updated list
--   becomes a conditional on the two indices.
reduceFoldMap :: (MonadAlpha m, MonadError Error m) => RewriteRule m
reduceFoldMap =
  mconcat
    [ -- reversed
      [r| "len/reversed" forall xs. len (reversed xs) = len xs |],
      [r| "elem/reversed" forall x xs. elem x (reversed xs) = elem x xs |],
      [r| "at/reversed" forall xs i. (reversed xs)[i] = xs[len(xs) - i - 1] |],
      -- sorted
      [r| "len/sorted" forall xs. len (sorted xs) = len xs |],
      [r| "elem/sorted" forall x xs. elem x (sorted xs) = elem x xs |],
      -- map
      [r| "len/map" forall f xs. len (map f xs) = len xs |],
      [r| "at/map" forall f xs i. (map f xs)[i] = f xs[i] |],
      [r| "foldl/map" forall g init f xs. foldl g init (map f xs) = foldl (fun y x -> g y (f x)) init xs|],
      -- set-at
      [r| "len/setat" forall xs i x. len xs[i <- x] = len xs |],
      -- scanl
      [r| "len/scanl" forall f init xs. len (scanl f init xs) = len xs + 1 |],
      -- subscript after set-at
      [r| "at/setat" forall xs i x j. xs[i <- x][j] = if i == j then x else xs[j] |]
    ]

-- | Rewrites a left fold whose step function never reads its element
-- argument into an @iterate@ on the accumulator alone.
--
-- The fold @foldl (fun acc x -> body) acc0 xs@ is rewritten to
-- @iterate (len xs) (fun acc -> body) acc0@. The rewrite applies only when
-- the element variable @x@ satisfies 'isUnusedVar' for @body@; otherwise
-- the rule does not fire.
reduceFold :: Monad m => RewriteRule m
reduceFold = simpleRewriteRule "foldl->iterate" $ \case
  -- tElem is the element type of xs (it is passed to Len'); tAcc is the
  -- accumulator type (it types the lambda parameter of Iterate').
  Foldl' tElem tAcc (Lam2 accVar _ eltVar _ body) acc0 xs
    | eltVar `isUnusedVar` body ->
      -- Only the number of steps matters, so count the elements of xs.
      Just $ Iterate' tAcc (Len' tElem xs) (Lam accVar tAcc body) acc0
  _ -> Nothing

-- | Evaluates a consumer (@foldl@, @len@, subscripting, @elem@) on a list
-- that is built directly by @nil@, @cons@ or @range@. Additionally, @len@
-- of @build@ is reduced to the length of its base plus its count argument.
--
-- Subscripting @nil@ is rewritten to a bottom value whose error message
-- includes the formatted index expression.
reduceFoldBuild :: (MonadAlpha m, MonadError Error m) => RewriteRule m
reduceFoldBuild =
  mconcat
    [ -- foldl
      [r| "foldl/nil" forall f init. foldl f init nil = init |],
      [r| "foldl/cons" forall f init x xs. foldl f init (cons x xs) = foldl f (f init x) xs |],
      -- len
      [r| "len/nil" len nil = 0 |],
      [r| "len/cons" forall x xs. len (cons x xs) = 1 + len xs |],
      [r| "len/range" forall n. len (range n) = n |],
      -- subscript
      -- Hand-written rather than quasi-quoted: the replacement is a Bottom'
      -- whose message is built with formatExpr when the rule fires.
      simpleRewriteRule "at/nil" $ \case
        At' t (Nil' _) i -> Just $ Bottom' t $ "cannot subscript empty list: index = " ++ formatExpr i
        _ -> Nothing,
      [r| "at/cons" forall x xs i. (cons x xs)[i] = if i == 0 then x else xs[i - 1] |],
      [r| "at/range" forall n i. (range n)[i] = i |],
      -- elem
      [r| "elem/nil" forall y. elem y nil = false |],
      [r| "elem/cons" forall y x xs. elem y (cons x xs) = y == x || elem y xs |],
      [r| "elem/range" forall i n. elem i (range n) = 0 <= i && i < n |],
      -- build
      [r| "len/build" forall f base n. len (build f base n) = len base + n |]
    ]

-- | All short cut fusion rules of this module, combined with 'mconcat'.
--
-- The component rules are listed in the order in which they are combined.
-- NOTE(review): whether this order affects which rule wins at a given node
-- depends on the Monoid instance of 'RewriteRule'; confirm before
-- reordering.
rule :: (MonadAlpha m, MonadError Error m) => RewriteRule m
rule =
  mconcat
    [ reduceFoldMap,
      reduceMap,
      reduceMapMap,
      reduceFoldBuild,
      reduceMapBuild,
      reduceBuild,
      reduceFold
    ]

-- | Rewrites a whole program with 'rule' by way of
-- 'applyRewriteRuleProgram''.
runProgram :: (MonadAlpha m, MonadError Error m) => Program -> m Program
runProgram = applyRewriteRuleProgram' rule

-- | Applies short cut fusion to a program.
--
-- The input program is linted as a precondition and the rewritten program
-- as a postcondition. The whole computation runs under 'wrapError'' with
-- this module's name.
run :: (MonadAlpha m, MonadError Error m) => Program -> m Program
run prog = wrapError' "Jikka.Core.Convert.ShortCutFusion" $ do
  precondition $ do
    lint prog
  -- A fresh name avoids shadowing the input program.
  prog' <- runProgram prog
  postcondition $ do
    lint prog'
  return prog'