module Language.PureScript.CodeGen.JS.Optimizer.Inliner
( inlineVariables
, inlineCommonValues
, inlineOperator
, inlineCommonOperators
, inlineFnComposition
, etaConvert
, unThunk
, evaluateIifes
) where
import Prelude ()
import Prelude.Compat
import Control.Monad.Supply.Class (MonadSupply, freshName)
import Data.Maybe (fromMaybe)
import Language.PureScript.CodeGen.JS.AST
import Language.PureScript.CodeGen.JS.Common
import Language.PureScript.Names
import Language.PureScript.CodeGen.JS.Optimizer.Common
import qualified Language.PureScript.Constants as C
-- | Decide whether an expression may be copied to its use sites.
--
-- Variables and numeric, string and boolean literals qualify directly.
-- A property access qualifies when its target does, and an indexing
-- expression qualifies when both the index and the target do. Every other
-- form of expression is rejected.
shouldInline :: JS -> Bool
shouldInline js = case js of
  JSVar _ -> True
  JSNumericLiteral _ -> True
  JSStringLiteral _ -> True
  JSBooleanLiteral _ -> True
  JSAccessor _ target -> shouldInline target
  JSIndexer key target -> shouldInline key && shouldInline target
  _ -> False
-- | Remove two kinds of redundant function wrappers, everywhere in the tree.
--
-- * A block consisting solely of @return (function (params) { body })(args)@
--   becomes a block containing @body@ with each parameter replaced by the
--   corresponding argument. This only happens when every argument satisfies
--   'shouldInline' and neither the parameters nor the arguments are reported
--   by @isRebound@ for the function body.
--
-- * An anonymous nullary function whose body is just @return fn()@ becomes
--   @fn@ itself.
etaConvert :: JS -> JS
etaConvert = everywhereOnJS rewrite
  where
  rewrite :: JS -> JS
  rewrite (JSBlock [JSReturn (JSApp (JSFunction Nothing params fnBody@(JSBlock stmts)) actuals)])
    | all shouldInline actuals
        && all notRebound (map JSVar params)
        && all notRebound actuals
    = JSBlock (map (replaceIdents (zip params actuals)) stmts)
    where
      -- True when the expression is not rebound inside the function body.
      notRebound expr = not (expr `isRebound` fnBody)
  rewrite (JSFunction Nothing [] (JSBlock [JSReturn (JSApp callee [])])) = callee
  rewrite other = other
-- | Flatten a trailing thunk call at the end of a block.
--
-- When the final statement of a block is
-- @return (function () { body })()@, that statement is replaced by the
-- statements of @body@, appended after the statements that preceded it.
-- Blocks whose final statement has any other shape (including empty blocks)
-- are returned unchanged.
unThunk :: JS -> JS
unThunk = everywhereOnJS rewrite
  where
  rewrite :: JS -> JS
  rewrite block@(JSBlock stmts) =
    -- Look at the statements back to front so the final one can be matched.
    case reverse stmts of
      JSReturn (JSApp (JSFunction Nothing [] (JSBlock inner)) []) : precedingRev ->
        JSBlock (reverse precedingRev ++ inner)
      _ -> block
  rewrite other = other
-- | Evaluate trivial immediately-invoked function expressions.
--
-- An application of an anonymous nullary function to no arguments, where the
-- function body is a single @return expr@, is replaced by @expr@.
evaluateIifes :: JS -> JS
evaluateIifes = everywhereOnJS unwrap
  where
  unwrap :: JS -> JS
  unwrap js = case js of
    JSApp (JSFunction Nothing [] (JSBlock [JSReturn result])) [] -> result
    _ -> js
-- | Inline variable introductions whose initializer is cheap to duplicate.
--
-- Within each block, a @var name = value@ statement is dropped and every
-- later statement has @name@ replaced by @value@, provided that @value@
-- satisfies 'shouldInline' and none of the later statements is flagged by
-- @isReassigned name@, @isRebound value@ or @isUpdated name@.
inlineVariables :: JS -> JS
inlineVariables = everywhereOnJS (removeFromBlock substitute)
  where
  substitute :: [JS] -> [JS]
  substitute [] = []
  substitute (JSVariableIntroduction name (Just value) : rest)
    | canInline = substitute (map (replaceIdent name value) rest)
    where
      canInline =
        shouldInline value
          && not (any (isReassigned name) rest)
          && not (any (isRebound value) rest)
          && not (any (isUpdated name) rest)
  -- Anything else (including declarations that failed the check above) is
  -- kept, and processing continues with the remaining statements.
  substitute (stmt : rest) = stmt : substitute rest
-- | Replace applications of common type-class members to known instance
-- dictionaries with literals or primitive JavaScript expressions.
--
-- * @zero@ and @one@ applied to the Number or Int semiring dictionary become
--   the numeric literals 0 and 1.
--
-- * @bottom@ and @top@ applied to the Boolean bounded dictionary become
--   @false@ and @true@.
--
-- * Addition, multiplication, division and subtraction applied to the
--   corresponding Int dictionary become the matching binary operator with
--   the result combined with literal 0 using bitwise or.
inlineCommonValues :: JS -> JS
inlineCommonValues = everywhereOnJS rewrite
  where
  rewrite :: JS -> JS
  rewrite (JSApp fn [dict])
    | numericDict && isFn' fnZero fn = JSNumericLiteral (Left 0)
    | numericDict && isFn' fnOne fn = JSNumericLiteral (Left 1)
    | booleanDict && isFn' fnBottom fn = JSBooleanLiteral False
    | booleanDict && isFn' fnTop fn = JSBooleanLiteral True
    where
      numericDict = isDict' (semiringNumber ++ semiringInt) dict
      booleanDict = isDict' boundedBoolean dict
  rewrite (JSApp (JSApp (JSApp fn [dict]) [lhs]) [rhs])
    | isDict' semiringInt dict && isFn' fnAdd fn = orZero Add lhs rhs
    | isDict' semiringInt dict && isFn' fnMultiply fn = orZero Multiply lhs rhs
    | isDict' moduloSemiringInt dict && isFn' fnDivide fn = orZero Divide lhs rhs
    | isDict' ringInt dict && isFn' fnSubtract fn = orZero Subtract lhs rhs
  rewrite js = js

  -- (module, name) pairs under which each class member may be referenced.
  fnZero = [(C.prelude, C.zero), (C.dataSemiring, C.zero)]
  fnOne = [(C.prelude, C.one), (C.dataSemiring, C.one)]
  fnBottom = [(C.prelude, C.bottom), (C.dataBounded, C.bottom)]
  fnTop = [(C.prelude, C.top), (C.dataBounded, C.top)]
  fnAdd =
    [ (C.prelude, (C.+))
    , (C.prelude, (C.add))
    , (C.dataSemiring, (C.+))
    , (C.dataSemiring, (C.add))
    ]
  fnDivide =
    [ (C.prelude, (C./))
    , (C.prelude, (C.div))
    , (C.dataModuloSemiring, C.div)
    ]
  fnMultiply =
    [ (C.prelude, (C.*))
    , (C.prelude, (C.mul))
    , (C.dataSemiring, (C.*))
    , (C.dataSemiring, (C.mul))
    ]
  fnSubtract =
    [ (C.prelude, (C.-))
    , (C.prelude, C.sub)
    , (C.dataRing, C.sub)
    ]

  -- Build @(lhs op rhs) | 0@.
  orZero :: BinaryOperator -> JS -> JS -> JS
  orZero op lhs rhs = JSBinary BitwiseOr (JSBinary op lhs rhs) (JSNumericLiteral (Left 0))
-- | Inline every curried two-argument application of a given operator.
--
-- The operator is identified by a @(module, name)@ pair. Wherever it is
-- applied to two arguments in turn, the application is replaced by the
-- result of the supplied builder applied to those arguments. The operator
-- reference is recognised both as a property access using the JavaScript
-- form of the operator name and as a string-keyed index on the module
-- variable.
inlineOperator :: (String, String) -> (JS -> JS -> JS) -> JS -> JS
inlineOperator (m, op) f = everywhereOnJS rewrite
  where
  rewrite :: JS -> JS
  rewrite js = case js of
    JSApp (JSApp callee [lhs]) [rhs] | refersToOp callee -> f lhs rhs
    _ -> js

  refersToOp :: JS -> Bool
  refersToOp (JSAccessor prop (JSVar modName)) = m == modName && prop == identToJs (Op op)
  refersToOp (JSIndexer (JSStringLiteral key) (JSVar modName)) = m == modName && op == key
  refersToOp _ = False
-- | Inline applications of common operators and functions at known instance
-- dictionaries as primitive JavaScript unary/binary operators, and inline
-- the @mkFn@/@runFn@ helpers for arities 0 through 10.
--
-- All rewrites in the list are combined with @applyAll@.
inlineCommonOperators :: JS -> JS
inlineCommonOperators = applyAll $
  [ binary semiringNumber opAdd Add
  , binary semiringNumber opMul Multiply
  , binary ringNumber opSub Subtract
  , unary ringNumber opNegate Negate
  , binary ringInt opSub Subtract
  , unary ringInt opNegate Negate
  , binary moduloSemiringNumber opDiv Divide
  , binary moduloSemiringInt opMod Modulus
  , binary eqNumber opEq EqualTo
  , binary eqNumber opNotEq NotEqualTo
  , binary eqInt opEq EqualTo
  , binary eqInt opNotEq NotEqualTo
  , binary eqString opEq EqualTo
  , binary eqString opNotEq NotEqualTo
  , binary eqChar opEq EqualTo
  , binary eqChar opNotEq NotEqualTo
  , binary eqBoolean opEq EqualTo
  , binary eqBoolean opNotEq NotEqualTo
  , binary ordBoolean opLessThan LessThan
  , binary ordBoolean opLessThanOrEq LessThanOrEqualTo
  , binary ordBoolean opGreaterThan GreaterThan
  , binary ordBoolean opGreaterThanOrEq GreaterThanOrEqualTo
  , binary ordChar opLessThan LessThan
  , binary ordChar opLessThanOrEq LessThanOrEqualTo
  , binary ordChar opGreaterThan GreaterThan
  , binary ordChar opGreaterThanOrEq GreaterThanOrEqualTo
  , binary ordInt opLessThan LessThan
  , binary ordInt opLessThanOrEq LessThanOrEqualTo
  , binary ordInt opGreaterThan GreaterThan
  , binary ordInt opGreaterThanOrEq GreaterThanOrEqualTo
  , binary ordNumber opLessThan LessThan
  , binary ordNumber opLessThanOrEq LessThanOrEqualTo
  , binary ordNumber opGreaterThan GreaterThan
  , binary ordNumber opGreaterThanOrEq GreaterThanOrEqualTo
  , binary ordString opLessThan LessThan
  , binary ordString opLessThanOrEq LessThanOrEqualTo
  , binary ordString opGreaterThan GreaterThan
  , binary ordString opGreaterThanOrEq GreaterThanOrEqualTo
  , binary semigroupString opAppend Add
  , binary booleanAlgebraBoolean opConj And
  , binary booleanAlgebraBoolean opDisj Or
  , unary booleanAlgebraBoolean opNot Not
  , binary' C.dataIntBits (C..|.) BitwiseOr
  , binary' C.dataIntBits (C..&.) BitwiseAnd
  , binary' C.dataIntBits (C..^.) BitwiseXor
  , binary' C.dataIntBits C.shl ShiftLeft
  , binary' C.dataIntBits C.shr ShiftRight
  , binary' C.dataIntBits C.zshr ZeroFillShiftRight
  , unary' C.dataIntBits C.complement BitwiseNot
  ] ++
  [ fn | i <- [0..10], fn <- [ mkFn i, runFn i ] ]
  where
  -- Rewrite @fn(dict)(x)(y)@ to @x op y@ when @dict@ is one of the given
  -- dictionaries and @fn@ is one of the given class members.
  binary :: [(String, String)] -> [(String, String)] -> BinaryOperator -> JS -> JS
  binary dict fns op = everywhereOnJS convert
    where
    convert :: JS -> JS
    convert (JSApp (JSApp (JSApp fn [dict']) [x]) [y]) | isDict' dict dict' && isFn' fns fn = JSBinary op x y
    convert other = other

  -- Rewrite @fn(x)(y)@ to @x op y@ for a plain (dictionary-free) function
  -- identified by module and name.
  binary' :: String -> String -> BinaryOperator -> JS -> JS
  binary' moduleName opString op = everywhereOnJS convert
    where
    convert :: JS -> JS
    convert (JSApp (JSApp fn [x]) [y]) | isFn (moduleName, opString) fn = JSBinary op x y
    convert other = other

  -- Rewrite @fn(dict)(x)@ to @op x@ when @dict@ is one of the given
  -- dictionaries and @fn@ is one of the given class members.
  unary :: [(String, String)] -> [(String, String)] -> UnaryOperator -> JS -> JS
  unary dicts fns op = everywhereOnJS convert
    where
    convert :: JS -> JS
    convert (JSApp (JSApp fn [dict']) [x]) | isDict' dicts dict' && isFn' fns fn = JSUnary op x
    convert other = other

  -- Rewrite @fn(x)@ to @op x@ for a plain (dictionary-free) function
  -- identified by module and name.
  unary' :: String -> String -> UnaryOperator -> JS -> JS
  unary' moduleName fnName op = everywhereOnJS convert
    where
    convert :: JS -> JS
    convert (JSApp fn [x]) | isFn (moduleName, fnName) fn = JSUnary op x
    convert other = other

  -- Inline @mkFnN@ applied to a chain of N nested single-argument functions,
  -- producing one anonymous function taking all N arguments.
  mkFn :: Int -> JS -> JS
  -- Arity 0: the wrapped function takes one (ignored) argument; the result
  -- is a function of no arguments with the same body.
  mkFn 0 = everywhereOnJS convert
    where
    convert :: JS -> JS
    convert (JSApp mkFnN [JSFunction Nothing [_] (JSBlock js)]) | isNFn C.mkFn 0 mkFnN =
      JSFunction Nothing [] (JSBlock js)
    convert other = other
  mkFn n = everywhereOnJS convert
    where
    convert :: JS -> JS
    convert orig@(JSApp mkFnN [fn]) | isNFn C.mkFn n mkFnN =
      case collectArgs n [] fn of
        Just (args, js) -> JSFunction Nothing args (JSBlock js)
        -- The argument is not a chain of n nested functions: leave as is.
        Nothing -> orig
    convert other = other

    -- Walk down the nested functions, counting the remaining depth down to
    -- 1 and accumulating parameter names in reverse. On the innermost
    -- function exactly n - 1 names must already have been collected.
    collectArgs :: Int -> [String] -> JS -> Maybe ([String], [JS])
    collectArgs 1 acc (JSFunction Nothing [oneArg] (JSBlock js)) | length acc == n - 1 = Just (reverse (oneArg : acc), js)
    collectArgs m acc (JSFunction Nothing [oneArg] (JSBlock [JSReturn ret])) = collectArgs (m - 1) (oneArg : acc) ret
    collectArgs _ _ _ = Nothing

  -- Recognise a reference to the helper named @prefix@ followed by the
  -- arity, either as a bare variable or as a property of the
  -- @C.dataFunction@ module variable.
  isNFn :: String -> Int -> JS -> Bool
  isNFn prefix n (JSVar name) = name == (prefix ++ show n)
  isNFn prefix n (JSAccessor name (JSVar dataFunction)) | dataFunction == C.dataFunction = name == (prefix ++ show n)
  isNFn _ _ _ = False

  -- Inline @runFnN(fn)(a1)...(aN)@ as the single call @fn(a1, ..., aN)@.
  runFn :: Int -> JS -> JS
  runFn n = everywhereOnJS convert
    where
    convert :: JS -> JS
    convert js = fromMaybe js $ go n [] js

    -- Peel off single-argument applications from the outside in, counting
    -- down from n. Arguments are prepended, so the accumulator ends up in
    -- source order once the @runFnN(fn)@ head is reached.
    go :: Int -> [JS] -> JS -> Maybe JS
    go 0 acc (JSApp runFnN [fn]) | isNFn C.runFn n runFnN && length acc == n = Just (JSApp fn acc)
    go m acc (JSApp lhs [arg]) = go (m - 1) (arg : acc) lhs
    go _ _ _ = Nothing
-- | Inline function composition at the function semigroupoid dictionary.
--
-- A fully applied composition @compose(dict)(x)(y)(z)@ becomes @x(y(z))@,
-- and the flipped form becomes @y(x(z))@. A composition applied only to its
-- two functions becomes an anonymous one-argument function with a fresh
-- parameter name, whose body returns the corresponding nested call.
inlineFnComposition :: (Applicative m, MonadSupply m) => JS -> m JS
inlineFnComposition = everywhereOnJSTopDownM rewrite
  where
  rewrite :: (MonadSupply m) => JS -> m JS
  rewrite (JSApp (JSApp (JSApp (JSApp fn [dict']) [f]) [g]) [input])
    | isFnCompose dict' fn = return (JSApp f [JSApp g [input]])
    | isFnComposeFlipped dict' fn = return (JSApp g [JSApp f [input]])
  rewrite (JSApp (JSApp (JSApp fn [dict']) [f]) [g])
    | isFnCompose dict' fn = composed f g
    | isFnComposeFlipped dict' fn = composed g f
  rewrite other = return other

  -- Build @function (arg) { return outer(inner(arg)); }@ with a fresh @arg@.
  composed :: (MonadSupply m) => JS -> JS -> m JS
  composed outer inner = do
    arg <- freshName
    return $ JSFunction Nothing [arg] (JSBlock [JSReturn $ JSApp outer [JSApp inner [JSVar arg]]])

  isFnCompose :: JS -> JS -> Bool
  isFnCompose dict' fn = isDict' semigroupoidFn dict' && isFn' fnCompose fn

  isFnComposeFlipped :: JS -> JS -> Bool
  isFnComposeFlipped dict' fn = isDict' semigroupoidFn dict' && isFn' fnComposeFlipped fn

  fnCompose :: [(String, String)]
  fnCompose =
    [ (C.prelude, C.compose)
    , (C.prelude, (C.<<<))
    , (C.controlSemigroupoid, C.compose)
    ]

  fnComposeFlipped :: [(String, String)]
  fnComposeFlipped =
    [ (C.prelude, (C.>>>))
    , (C.controlSemigroupoid, C.composeFlipped)
    ]
-- | @(module, name)@ pairs under which the Number semiring dictionary may
-- be referenced.
semiringNumber :: [(String, String)]
semiringNumber =
  [ (C.prelude, C.semiringNumber)
  , (C.dataSemiring, C.semiringNumber)
  ]

-- | @(module, name)@ pairs for the Int semiring dictionary.
semiringInt :: [(String, String)]
semiringInt =
  [ (C.prelude, C.semiringInt)
  , (C.dataSemiring, C.semiringInt)
  ]

-- | @(module, name)@ pairs for the Number ring dictionary.
ringNumber :: [(String, String)]
ringNumber =
  [ (C.prelude, C.ringNumber)
  , (C.dataRing, C.ringNumber)
  ]

-- | @(module, name)@ pairs for the Int ring dictionary.
ringInt :: [(String, String)]
ringInt =
  [ (C.prelude, C.ringInt)
  , (C.dataRing, C.ringInt)
  ]

-- | @(module, name)@ pairs for the Number modulo-semiring dictionary.
moduloSemiringNumber :: [(String, String)]
moduloSemiringNumber =
  [ (C.prelude, C.moduloSemiringNumber)
  , (C.dataModuloSemiring, C.moduloSemiringNumber)
  ]

-- | @(module, name)@ pairs for the Int modulo-semiring dictionary.
moduloSemiringInt :: [(String, String)]
moduloSemiringInt =
  [ (C.prelude, C.moduloSemiringInt)
  , (C.dataModuloSemiring, C.moduloSemiringInt)
  ]

-- | @(module, name)@ pairs for the Number equality dictionary.
eqNumber :: [(String, String)]
eqNumber =
  [ (C.prelude, C.eqNumber)
  , (C.dataEq, C.eqNumber)
  ]

-- | @(module, name)@ pairs for the Int equality dictionary.
eqInt :: [(String, String)]
eqInt =
  [ (C.prelude, C.eqInt)
  , (C.dataEq, C.eqInt)
  ]

-- | @(module, name)@ pairs for the String equality dictionary.
eqString :: [(String, String)]
eqString =
  [ (C.prelude, C.eqString)
  , (C.dataEq, C.eqString)
  ]

-- | @(module, name)@ pairs for the Char equality dictionary.
eqChar :: [(String, String)]
eqChar =
  [ (C.prelude, C.eqChar)
  , (C.dataEq, C.eqChar)
  ]

-- | @(module, name)@ pairs for the Boolean equality dictionary.
eqBoolean :: [(String, String)]
eqBoolean =
  [ (C.prelude, C.eqBoolean)
  , (C.dataEq, C.eqBoolean)
  ]

-- | @(module, name)@ pairs for the Boolean ordering dictionary.
ordBoolean :: [(String, String)]
ordBoolean =
  [ (C.prelude, C.ordBoolean)
  , (C.dataOrd, C.ordBoolean)
  ]

-- | @(module, name)@ pairs for the Number ordering dictionary.
ordNumber :: [(String, String)]
ordNumber =
  [ (C.prelude, C.ordNumber)
  , (C.dataOrd, C.ordNumber)
  ]

-- | @(module, name)@ pairs for the Int ordering dictionary.
ordInt :: [(String, String)]
ordInt =
  [ (C.prelude, C.ordInt)
  , (C.dataOrd, C.ordInt)
  ]

-- | @(module, name)@ pairs for the String ordering dictionary.
ordString :: [(String, String)]
ordString =
  [ (C.prelude, C.ordString)
  , (C.dataOrd, C.ordString)
  ]

-- | @(module, name)@ pairs for the Char ordering dictionary.
ordChar :: [(String, String)]
ordChar =
  [ (C.prelude, C.ordChar)
  , (C.dataOrd, C.ordChar)
  ]

-- | @(module, name)@ pairs for the String semigroup dictionary.
semigroupString :: [(String, String)]
semigroupString =
  [ (C.prelude, C.semigroupString)
  , (C.dataSemigroup, C.semigroupString)
  ]

-- | @(module, name)@ pairs for the Boolean bounded dictionary.
boundedBoolean :: [(String, String)]
boundedBoolean =
  [ (C.prelude, C.boundedBoolean)
  , (C.dataBounded, C.boundedBoolean)
  ]

-- | @(module, name)@ pairs for the Boolean boolean-algebra dictionary.
booleanAlgebraBoolean :: [(String, String)]
booleanAlgebraBoolean =
  [ (C.prelude, C.booleanAlgebraBoolean)
  , (C.dataBooleanAlgebra, C.booleanAlgebraBoolean)
  ]

-- | @(module, name)@ pairs for the function semigroupoid dictionary.
semigroupoidFn :: [(String, String)]
semigroupoidFn =
  [ (C.prelude, C.semigroupoidFn)
  , (C.controlSemigroupoid, C.semigroupoidFn)
  ]
-- | @(module, name)@ pairs under which addition may be referenced.
opAdd :: [(String, String)]
opAdd =
  [ (C.prelude, (C.+))
  , (C.prelude, C.add)
  , (C.dataSemiring, C.add)
  ]

-- | @(module, name)@ pairs for multiplication.
opMul :: [(String, String)]
opMul =
  [ (C.prelude, (C.*))
  , (C.prelude, C.mul)
  , (C.dataSemiring, C.mul)
  ]

-- | @(module, name)@ pairs for equality.
opEq :: [(String, String)]
opEq =
  [ (C.prelude, (C.==))
  , (C.prelude, C.eq)
  , (C.dataEq, C.eq)
  ]

-- | @(module, name)@ pairs for inequality.
opNotEq :: [(String, String)]
opNotEq =
  [ (C.prelude, (C./=))
  , (C.dataEq, C.notEq)
  ]

-- | @(module, name)@ pairs for the less-than comparison.
opLessThan :: [(String, String)]
opLessThan =
  [ (C.prelude, (C.<))
  , (C.dataOrd, C.lessThan)
  ]

-- | @(module, name)@ pairs for the less-than-or-equal comparison.
opLessThanOrEq :: [(String, String)]
opLessThanOrEq =
  [ (C.prelude, (C.<=))
  , (C.dataOrd, C.lessThanOrEq)
  ]

-- | @(module, name)@ pairs for the greater-than comparison.
opGreaterThan :: [(String, String)]
opGreaterThan =
  [ (C.prelude, (C.>))
  , (C.dataOrd, C.greaterThan)
  ]

-- | @(module, name)@ pairs for the greater-than-or-equal comparison.
opGreaterThanOrEq :: [(String, String)]
opGreaterThanOrEq =
  [ (C.prelude, (C.>=))
  , (C.dataOrd, C.greaterThanOrEq)
  ]

-- | @(module, name)@ pairs for semigroup append.
opAppend :: [(String, String)]
opAppend =
  [ (C.prelude, (C.<>))
  , (C.prelude, (C.++))
  , (C.prelude, C.append)
  , (C.dataSemigroup, C.append)
  ]

-- | @(module, name)@ pairs for subtraction.
opSub :: [(String, String)]
opSub =
  [ (C.prelude, (C.-))
  , (C.prelude, C.sub)
  , (C.dataRing, C.sub)
  ]

-- | @(module, name)@ pairs for negation.
opNegate :: [(String, String)]
opNegate =
  [ (C.prelude, C.negate)
  , (C.dataRing, C.negate)
  ]

-- | @(module, name)@ pairs for division.
opDiv :: [(String, String)]
opDiv =
  [ (C.prelude, (C./))
  , (C.prelude, C.div)
  , (C.dataModuloSemiring, C.div)
  ]

-- | @(module, name)@ pairs for the modulus operation.
opMod :: [(String, String)]
opMod =
  [ (C.prelude, C.mod)
  , (C.dataModuloSemiring, C.mod)
  ]

-- | @(module, name)@ pairs for boolean conjunction.
opConj :: [(String, String)]
opConj =
  [ (C.prelude, (C.&&))
  , (C.prelude, C.conj)
  , (C.dataBooleanAlgebra, C.conj)
  ]

-- | @(module, name)@ pairs for boolean disjunction.
opDisj :: [(String, String)]
opDisj =
  [ (C.prelude, (C.||))
  , (C.prelude, C.disj)
  , (C.dataBooleanAlgebra, C.disj)
  ]

-- | @(module, name)@ pairs for boolean negation.
opNot :: [(String, String)]
opNot =
  [ (C.prelude, C.not)
  , (C.dataBooleanAlgebra, C.not)
  ]