-- SPDX-FileCopyrightText: 2023 Oxhead Alpha
-- SPDX-License-Identifier: LicenseRef-MIT-OA

{-# OPTIONS_HADDOCK not-home #-}

-- | Optimizer rule and ruleset definitions.
module Morley.Michelson.Optimizer.Internal.Ruleset
  ( module Morley.Michelson.Optimizer.Internal.Ruleset
  ) where

import Prelude

import Data.Default (Default(def))
import Data.Map qualified as Map
import Fmt (Buildable(..), (+|), (|+))

import Morley.Michelson.Typed.Instr

-- | Type of a single rewrite rule, wrapped in `newtype`. It takes an
-- instruction and tries to optimize its head (first few instructions). If
-- optimization succeeds, it returns `Just` the optimized instruction, otherwise
-- it returns `Nothing`.
--
-- The wrapped function is polymorphic in both the input and output stack
-- types, so one 'Rule' can be tried on any instruction; the result always has
-- the same stack type as the argument.
newtype Rule = Rule
  { unRule :: forall inp out. Instr inp out -> Maybe (Instr inp out)
  }

-- | Optimization stages. Stages are run in first to last order, each stage has
-- an 'Int' argument, which allows splitting each stage into sub-stages, which
-- will run lowest index to highest. All default rules use sub-stage @0@.
--
-- The derived 'Ord' instance encodes this ordering: values compare by
-- constructor (in declaration order) first, and the sub-stage index only
-- breaks ties between values built with the same constructor.
data OptimizationStage
  = OptimizationStagePrepare Int
    -- ^ Preparation stage; ordered before every other stage.
  | OptimizationStageMain Int
    -- ^ Main optimization stage, except rules that would interfere with other
    -- rules.
  | OptimizationStageMainExtended Int
    -- ^ All main stage rules.
  | OptimizationStageFixup Int
    -- ^ Post main stage fixups.
  | OptimizationStageRollAdjacent Int
    -- ^ Main stage rules unroll @DROP n@, @PAIR n@, etc. into their primitive
    -- counterparts to simplify some optimizations. This stage coalesces them
    -- back.
  deriving stock (Eq, Ord)

-- | Human-readable rendering of a stage: its lowercase name followed by the
-- sub-stage index, e.g. @main extended 0@.
instance Buildable OptimizationStage where
  build = \case
    OptimizationStagePrepare n -> "prepare " +| n |+ ""
    OptimizationStageMain n -> "main " +| n |+ ""
    OptimizationStageMainExtended n -> "main extended " +| n |+ ""
    OptimizationStageFixup n -> "fixup " +| n |+ ""
    OptimizationStageRollAdjacent n -> "roll adjacent " +| n |+ ""

-- | A set of optimization stages. Rules at the same sub-stage are applied in
-- arbitrary order. See 'OptimizationStage' for explanation of sub-stages.
--
-- Every stage stored in the map carries at least one rule (hence 'NonEmpty');
-- a stage without rules is represented by the absence of its key.
--
-- 'Default' ruleset is empty.
newtype Ruleset = Ruleset
  { unRuleset :: Map OptimizationStage (NonEmpty Rule)
  }
  deriving newtype Default

-- | Union of two rulesets. Stages present in only one operand are kept as is;
-- for stages present in both, the rule lists are concatenated with the left
-- operand's rules first.
instance Semigroup Ruleset where
  Ruleset l <> Ruleset r = Ruleset $ Map.unionWith (<>) l r

-- | The identity ruleset is the empty one, i.e. the same as 'def'.
instance Monoid Ruleset where
  mempty = def

-- | Get rules for a given priority as a list. Returns @[]@ when the ruleset
-- has no rules at that stage.
rulesAtPrio :: OptimizationStage -> Ruleset -> [Rule]
rulesAtPrio prio = maybe [] toList . Map.lookup prio . unRuleset

-- | Insert a single rule at a given priority without touching other rules.
-- The new rule is prepended to whatever rules the stage already has; if the
-- stage was absent, it is created with just this rule.
insertRuleAtPrio :: OptimizationStage -> Rule -> Ruleset -> Ruleset
insertRuleAtPrio prio rule = alterRulesAtPrio (rule :) prio

-- | Remove the stage with the given priority, dropping all of its rules.
-- Other stages are untouched; clearing a stage that is not present is a no-op.
clearRulesAtPrio :: OptimizationStage -> Ruleset -> Ruleset
clearRulesAtPrio = alterRulesAtPrio (const [])

-- | Alter all stage rules for a given priority.
--
-- The function receives the stage's current rules (@[]@ if the stage is
-- absent). If it returns an empty list, the stage is removed from the ruleset
-- altogether, because the underlying map only stores 'NonEmpty' rule lists.
alterRulesAtPrio :: ([Rule] -> [Rule]) -> OptimizationStage -> Ruleset -> Ruleset
alterRulesAtPrio f prio =
  Ruleset . Map.alter (nonEmpty . f . maybe [] toList) prio . unRuleset