-- |
-- Module      : Data.Express.Map
-- Copyright   : (c) 2019-2021 Rudy Matela
-- Maintainer  : Rudy Matela <rudy@matela.com.br>
--
-- Utilities for mapping or transforming 'Expr's.
module Data.Express.Map
( mapValues
, mapVars
, mapConsts
, mapSubexprs
, (//-)
, (//)
, renameVarsBy
)
where

import Data.Express.Core
import Data.Express.Utils.List
import Data.Maybe (fromMaybe)

-- | /O(n*m)/.
-- Applies a function to all terminal values in an expression.
-- (cf. '//-')
--
-- Every subexpression that is not an application (':$') is passed to the
-- given function; application nodes are rebuilt from their transformed
-- halves and are never passed to the function themselves.
--
-- Given that:
--
-- > > let zero  = val (0 :: Int)
-- > > let one   = val (1 :: Int)
-- > > let two   = val (2 :: Int)
-- > > let three = val (3 :: Int)
-- > > let xx -+- yy = value "+" ((+) :: Int->Int->Int) :$ xx :$ yy
-- > > let intToZero e = if typ e == typ zero then zero else e
--
-- Then:
--
-- > > one -+- (two -+- three)
-- > 1 + (2 + 3) :: Int
--
-- > > mapValues intToZero $ one -+- (two -+- three)
-- > 0 + (0 + 0) :: Int
--
-- (The @+@ operator is itself a terminal value and is passed to
-- @intToZero@, but it is left alone because its type is not @Int@.)
--
-- Given that the argument function is /O(m)/, this function is /O(n*m)/.
mapValues :: (Expr -> Expr) -> Expr -> Expr
mapValues f  =  m
  where
    -- Descend through applications; apply @f@ only at the leaves.
    m :: Expr -> Expr
    m (e1 :$ e2)  =  m e1 :$ m e2
    m e           =  f e

-- | /O(n*m)/.
-- Applies a function to all variables in an expression.
-- Terminal values that are not variables are left unchanged.
-- (cf. 'mapValues', 'mapConsts')
--
-- Given that:
--
-- > > let primeify e = if isVar e
-- > |                  then case e of (Value n d) -> Value (n ++ "'") d
-- > |                  else e
-- > > let xx = var "x" (undefined :: Int)
-- > > let yy = var "y" (undefined :: Int)
-- > > let xx -+- yy = value "+" ((+) :: Int->Int->Int) :$ xx :$ yy
--
-- Then:
--
-- > > xx -+- yy
-- > x + y :: Int
--
-- > > primeify xx
-- > x' :: Int
--
-- > > mapVars primeify $ xx -+- yy
-- > x' + y' :: Int
--
-- > > mapVars (primeify . primeify) $ xx -+- yy
-- > x'' + y'' :: Int
--
-- Given that the argument function is /O(m)/, this function is /O(n*m)/.
mapVars :: (Expr -> Expr) -> Expr -> Expr
mapVars f  =  mapValues onVar
  where
    -- Only leaves for which 'isVar' holds are handed to @f@;
    -- every other leaf is returned as is.
    onVar :: Expr -> Expr
    onVar e
      | isVar e    =  f e
      | otherwise  =  e

-- | /O(n*m)/.
-- Applies a function to all terminal constants in an expression.
-- Terminal values that are not constants (e.g. variables) are left unchanged.
-- (cf. 'mapValues', 'mapVars')
--
-- Given that:
--
-- > > let zero  = val (0 :: Int)
-- > > let one   = val (1 :: Int)
-- > > let two   = val (2 :: Int)
-- > > let xx    = var "x" (undefined :: Int)
-- > > let xx -+- yy = value "+" ((+) :: Int->Int->Int) :$ xx :$ yy
-- > > let intToZero e = if typ e == typ zero then zero else e
--
-- Then:
--
-- > > one -+- (two -+- xx)
-- > 1 + (2 + x) :: Int
--
-- > > mapConsts intToZero (one -+- (two -+- xx))
-- > 0 + (0 + x) :: Int
--
-- Given that the argument function is /O(m)/, this function is /O(n*m)/.
mapConsts :: (Expr -> Expr) -> Expr -> Expr
mapConsts f  =  mapValues onConst
  where
    -- Only leaves for which 'isConst' holds are handed to @f@;
    -- every other leaf is returned as is.
    onConst :: Expr -> Expr
    onConst e
      | isConst e  =  f e
      | otherwise  =  e

-- | /O(n*m)/.
-- Substitute subexpressions of an expression using the given function.
-- Outer expressions have more precedence than inner expressions.
-- (cf. '//')
--
-- The function is first tried on the whole expression.  If it returns
-- @'Just' r@, @r@ is used as is.  If it returns 'Nothing' and the
-- expression is an application, both halves are processed recursively.
--
-- With:
--
-- > > let xx = var "x" (undefined :: Int)
-- > > let yy = var "y" (undefined :: Int)
-- > > let zz = var "z" (undefined :: Int)
-- > > let plus = value "+" ((+) :: Int->Int->Int)
-- > > let times = value "*" ((*) :: Int->Int->Int)
-- > > let xx -+- yy = plus :$ xx :$ yy
-- > > let xx -*- yy = times :$ xx :$ yy
--
-- > > let pluswap (o :$ xx :$ yy) | o == plus = Just $ o :$ yy :$ xx
-- > |     pluswap _                           = Nothing
--
-- Then:
--
-- > > mapSubexprs pluswap $ (xx -*- yy) -+- (yy -*- zz)
-- > y * z + x * y :: Int
--
-- > > mapSubexprs pluswap $ (xx -+- yy) -*- (yy -+- zz)
-- > (y + x) * (z + y) :: Int
--
-- Substitutions do not stack, in other words
-- a replaced expression or its subexpressions are not further replaced:
--
-- > > mapSubexprs pluswap $ (xx -+- yy) -+- (yy -+- zz)
-- > (y + z) + (x + y) :: Int
--
-- Given that the argument function is /O(m)/, this function is /O(n*m)/.
mapSubexprs :: (Expr -> Maybe Expr) -> Expr -> Expr
mapSubexprs f  =  m
  where
    -- Prefer @f@'s answer for the whole expression; the recursive
    -- descent is evaluated only when @f@ declines (laziness of 'fromMaybe').
    m :: Expr -> Expr
    m e  =  fromMaybe (descend e) (f e)

    -- Rebuild an application from its processed halves; leaves stay put.
    descend :: Expr -> Expr
    descend (e1 :$ e2)  =  m e1 :$ m e2
    descend leaf        =  leaf

-- | /O(n*m)/.
-- Substitute occurrences of values in an expression
-- from the given list of substitutions.
-- (cf. 'mapValues')
--
-- Terminal values with no entry in the list are kept unchanged.
--
-- Given that:
--
-- > > let xx = var "x" (undefined :: Int)
-- > > let yy = var "y" (undefined :: Int)
-- > > let zz = var "z" (undefined :: Int)
-- > > let xx -+- yy = value "+" ((+) :: Int->Int->Int) :$ xx :$ yy
--
-- Then:
--
-- > > ((xx -+- yy) -+- (yy -+- zz)) //- [(xx, yy), (zz, yy)]
-- > (y + y) + (y + y) :: Int
--
-- > > ((xx -+- yy) -+- (yy -+- zz)) //- [(yy, yy -+- zz)]
-- > (x + (y + z)) + ((y + z) + z) :: Int
--
-- This function does not work for substituting non-terminal subexpressions:
--
-- > > (xx -+- yy) //- [(xx -+- yy, zz)]
-- > x + y :: Int
--
-- Please use the slower '//' if you want the above replacement to work.
--
-- Replacement happens only once:
--
-- > > xx //- [(xx,yy), (yy,zz)]
-- > y :: Int
--
-- Given that the argument list has length /m/,
-- this function is /O(n*m)/.
(//-) :: Expr -> [(Expr,Expr)] -> Expr
e //- s  =  mapValues (`lookupId` s) e

-- | /O(n*n*m)/.
-- Substitute subexpressions in an expression
-- from the given list of substitutions.
-- (cf. 'mapSubexprs').
--
-- Lookups use Prelude's 'lookup', so if a subexpression appears as a key
-- more than once, the first matching pair wins.
--
-- Please consider using '//-' if you are replacing just terminal values
-- as it is faster.
--
-- Given that:
--
-- > > let xx = var "x" (undefined :: Int)
-- > > let yy = var "y" (undefined :: Int)
-- > > let zz = var "z" (undefined :: Int)
-- > > let xx -+- yy = value "+" ((+) :: Int->Int->Int) :$ xx :$ yy
--
-- Then:
--
-- > > ((xx -+- yy) -+- (yy -+- zz)) // [(xx -+- yy, yy), (yy -+- zz, yy)]
-- > y + y :: Int
--
-- > > ((xx -+- yy) -+- zz) // [(xx -+- yy, zz), (zz, xx -+- yy)]
-- > z + (x + y) :: Int
--
-- Replacement happens only once with outer expressions
-- having more precedence than inner expressions.
--
-- > > (xx -+- yy) // [(yy,xx), (xx -+- yy,zz), (zz,xx)]
-- > z :: Int
--
-- Given that the argument list has length /m/, this function is /O(n*n*m)/.
-- Remember that since /n/ is the size of an expression,
-- comparing two expressions is /O(n)/ in the worst case,
-- and we may need to compare with /n/ subexpressions in the worst case.
(//) :: Expr -> [(Expr,Expr)] -> Expr
e // s  =  mapSubexprs (`lookup` s) e

-- | Rename variables in an 'Expr'.
--
-- The renaming function receives the variable name without the leading
-- underscore, which is how variables are marked internally. The underscore
-- is put back afterwards.
--
-- Given that:
--
-- > > let xx = var "x" (undefined :: Int)
-- > > let yy = var "y" (undefined :: Int)
-- > > let zz = var "z" (undefined :: Int)
-- > > let xx -+- yy = value "+" ((+) :: Int->Int->Int) :$ xx :$ yy
-- > > let abs' xx = value "abs" (abs :: Int->Int) :$ xx
--
-- Then:
--
-- > > renameVarsBy (++ "'") (xx -+- yy)
-- > x' + y' :: Int
--
-- > > renameVarsBy (++ "'") (yy -+- (zz -+- xx))
-- > y' + (z' + x') :: Int
--
-- > > renameVarsBy (++ "1") (abs' xx)
-- > abs x1 :: Int
--
-- > > renameVarsBy (++ "2") $ abs' (xx -+- yy)
-- > abs (x2 + y2) :: Int
--
-- NOTE: this will affect holes!
renameVarsBy :: (String -> String) -> Expr -> Expr
renameVarsBy f  =  mapValues rename
  where
    -- Only values whose stored name starts with @'_'@ are renamed.
    -- The marker is kept; only the rest of the name goes through @f@.
    rename :: Expr -> Expr
    rename (Value ('_':n) t)  =  Value ('_' : f n) t
    rename e                  =  e