{-# LANGUAGE TypeFamilies #-}

-- | Constant-folding plugin: removes neutral integer operands from binary
-- infix calls.
--
-- The rewrites performed are:
--
-- * @x + 0@ and @0 + x@ become @x@;
--
-- * @x - 0@ becomes @x@ (subtraction is not commutative, so @0 - x@ is kept);
--
-- * @x * 1@ and @1 * x@ become @x@.
--
-- Every other function call is left unchanged.
module Feldspar.Compiler.Plugins.ConstantFolding where

import Feldspar.Compiler.PluginArchitecture

-- | Tag type selecting the constant-folding plugin and its transformation
-- phase.
data ConstantFolding = ConstantFolding

instance Plugin ConstantFolding where
    -- The plugin takes no external configuration.
    type ExternalInfo ConstantFolding = ()
    -- Run the transformation phase with a unit downward value and keep only
    -- the first component of its result (the transformed procedure).
    executePlugin ConstantFolding _ procedure =
        fst $ executeTransformationPhase ConstantFolding () procedure

instance TransformationPhase ConstantFolding where
    -- The phase carries no information in any direction.
    type From ConstantFolding = ()
    type To ConstantFolding = ()
    type Downwards ConstantFolding = ()
    type Upwards ConstantFolding = ()

    -- Only infix calls named "+", "-" or "*" are candidates for folding.
    -- "-" is passed 'False' for flippability because 0 - x /= x.
    transformFunctionCall ConstantFolding _ _ (InfosFromFunctionCallParts funData _) =
        case roleOfFunctionToCall funData of
            InfixOp -> case nameOfFunctionToCall funData of
                "+" -> elimParamIf (isConstIntN 0) True funCall
                "-" -> elimParamIf (isConstIntN 0) False funCall
                "*" -> elimParamIf (isConstIntN 1) True funCall
                _   -> FunctionCallExpression funCall
            _ -> FunctionCallExpression funCall
      where
        -- The call reassembled from its parts; it is returned as-is whenever
        -- no rewrite applies. NOTE(review): the second field is filled with
        -- () here -- confirm this is the intended empty info slot.
        funCall = FunctionCall funData ()

-- | @isConstIntN n e@ is 'True' exactly when @e@ is an integer constant
-- expression whose value equals @n@, and 'False' for every other expression.
isConstIntN n (ConstantExpression (IntConstant (IntConstantType i _))) = n == i
isConstIntN _ _ = False

-- | @elimParamIf isNeutral flippable call@ drops a neutral operand from a
-- binary infix call.
--
-- * If the right operand satisfies @isNeutral@, the left operand is returned.
--
-- * Otherwise, if @flippable@ (the operator is commutative) and the left
--   operand satisfies @isNeutral@, the right operand is returned.
--
-- * In every other case the call is returned unchanged as a
--   'FunctionCallExpression'. This includes non-infix calls and infix calls
--   that do not have exactly two operands. Matching exactly two operands
--   avoids crashing on a one-operand call, and avoids silently discarding
--   extra operands.
elimParamIf isNeutral flippable funCall@(FunctionCall (FunctionCallData InfixOp _ _ [lhs, rhs]) _)
    | isNeutral rhs = lhs
    | flippable && isNeutral lhs = rhs
    | otherwise = FunctionCallExpression funCall
elimParamIf _ _ funCall = FunctionCallExpression funCall