{-# LANGUAGE TypeFamilies #-}

module Feldspar.Compiler.Imperative.Plugin.ConstantFolding where

import Feldspar.Transformation

-- | Tag type for the constant-folding pass. It carries no data; its only
-- role is to select the 'Plugin', 'Transformation' and 'Transformable'
-- instances defined in this module.
data ConstantFolding = ConstantFolding

-- | Plugin entry point for constant folding. The external info argument is
-- unit and is ignored; the procedure is run through 'transform' with unit
-- state and unit downward info, and only the 'result' field is returned.
instance Plugin ConstantFolding where
  type ExternalInfo ConstantFolding = ()
  executePlugin ConstantFolding _ procedure = result folded
    where
      -- Full transformation record; everything except 'result' is discarded.
      folded = transform ConstantFolding () () procedure

-- | Every associated type of the transformation is unit: this pass uses
-- no state and passes no information down or up the traversal.
instance Transformation ConstantFolding where
    type State ConstantFolding = ()
    type Down ConstantFolding = ()
    type Up ConstantFolding = ()
    type From ConstantFolding = ()
    type To ConstantFolding = ()

-- | Removes arithmetic identity operands from binary infix calls. The
-- expression is first run through 'defaultTransform'; the rewritten call is
-- then simplified as follows:
--
-- * @x + 0@ and @0 + x@ become @x@
-- * @x - 0@ becomes @x@ (@0 - x@ is left untouched)
-- * @x * 1@ and @1 * x@ become @x@
--
-- Only operands that are literally @ConstExpr (IntConst ...)@ are recognised.
-- Every other expression is returned as produced by 'defaultTransform'.
instance Transformable ConstantFolding Expression where
    transform t s d f@(FunctionCall _ _ _ _ _ _) = case funRole f' of
        InfixOp -> case funCallName f' of
            "+"     -> tr' $ elimParamIf (isConstIntN 0) True  f'
            "-"     -> tr' $ elimParamIf (isConstIntN 0) False f'
            "*"     -> tr' $ elimParamIf (isConstIntN 1) True  f'
            _       -> tr
        _       -> tr
        where
            -- Result of the default traversal of the call.
            tr = defaultTransform t s d f
            -- Same traversal result, with the expression replaced by @x@.
            tr' x = tr {result = x}
            -- The call as rewritten by the default traversal.
            f' = result tr

            -- True when the expression is the integer literal @n@.
            isConstIntN n (ConstExpr (IntConst i _ _) _) = n == i
            isConstIntN _ _ = False

            -- Drops the operand satisfying @isIdentity@ and returns the other.
            -- The left operand may only be dropped when @commutative@ is set.
            -- Matching exactly two operands keeps this total: a call with any
            -- other arity is returned unchanged, instead of crashing on a
            -- single operand or silently discarding extra ones.
            elimParamIf isIdentity commutative funCall@(FunctionCall _ _ InfixOp [lhs, rhs] _ _)
                | isIdentity rhs                = lhs
                | commutative && isIdentity lhs = rhs
                | otherwise                     = funCall
            elimParamIf _ _ funCall = funCall
    transform t s d e = defaultTransform t s d e