module Dvda.Expr ( Expr(..)
, GExpr(..)
, Nums(..)
, Fractionals(..)
, Floatings(..)
, Sym(..)
, sym
, symDependent
, symDependentN
, const'
, extractLinearPart
, getConst
, substitute
, sketchySubstitute
, foldExpr
, fromNeg
) where
import Control.Applicative ( (<$>), pure )
import Data.Hashable ( Hashable(..), hash )
import Data.Ratio ( (%) )
import GHC.Generics ( Generic )
import Data.Monoid ( mempty )
import qualified Data.Foldable as F
import qualified Data.Traversable as T
import qualified Dvda.HashMap as HM
commutativeMul :: Bool
commutativeMul = True
commutativeAdd :: Bool
commutativeAdd = True
data Sym =
Sym String
| SymDependent String Int Sym
deriving (Eq, Ord, Generic)
instance Show Sym where
showsPrec d (Sym name) = showParen (d >= 9) $ showString name
showsPrec d (SymDependent name k s) =
showParen (d >= 9) $
showString $ name ++ replicate k '\'' ++ "(" ++ show s ++ ")"
data Expr a where
ESym :: Sym -> Expr a
EConst :: a -> Expr a
ENum :: Num a => Nums (Expr a) -> Expr a
EFractional :: Fractional a => Fractionals (Expr a) -> Expr a
EFloating :: Floating a => Floatings (Expr a) -> Expr a
data Nums a = Mul a a
| Add a a
| Sub a a
| Negate a
| Abs a
| Signum a
| FromInteger Integer
deriving (Ord, Generic, Functor, F.Foldable, T.Traversable)
data Fractionals a = Div a a
| FromRational Rational
deriving (Eq, Ord, Generic, Functor, F.Foldable, T.Traversable)
data Floatings a = Pow a a
| LogBase a a
| Exp a
| Log a
| Sin a
| Cos a
| Tan a
| ASin a
| ATan a
| ACos a
| Sinh a
| Cosh a
| Tanh a
| ASinh a
| ATanh a
| ACosh a
deriving (Eq, Ord, Generic, Functor, F.Foldable, T.Traversable)
showsInfixBinary :: (Show a, Show b) => Int -> Int -> String -> a -> b -> ShowS
showsInfixBinary d prec op u v = showParen (d >= prec) $
showsPrec prec u .
showString op .
showsPrec prec v
showsUnary :: Show a => Int -> Int -> String -> a -> ShowS
showsUnary d prec op u = showParen (d >= prec) $
showString op .
showsPrec prec u
instance Show a => Show (Nums a) where
showsPrec d (Mul x y) = showsInfixBinary d 7 " * " x y
showsPrec d (Add x y) = showsInfixBinary d 6 " + " x y
showsPrec d (Sub x y) = showsInfixBinary d 6 " - " x y
showsPrec d (Negate x) = showsUnary d 7 "-" x
showsPrec d (Abs x) = showsUnary d 10 "abs" x
showsPrec d (Signum x) = showsUnary d 10 "signum" x
showsPrec d (FromInteger k) = showParen (d >= 9) $ showString (show k)
instance Show a => Show (Fractionals a) where
showsPrec d (Div x y) = showsInfixBinary d 7 " / " x y
showsPrec d (FromRational r) = showParen (d >= 9) $ showString $ show (fromRational r :: Double)
instance Show a => Show (Floatings a) where
showsPrec d (Pow x y) = showsInfixBinary d 8 " ** " x y
showsPrec d (LogBase x y) = showParen (d > 10) $ showString $ "logBase(" ++ show x ++ ", " ++ show y ++ ")"
showsPrec d (Exp x) = showsUnary d 10 "exp" x
showsPrec d (Log x) = showsUnary d 10 "log" x
showsPrec d (Sin x) = showsUnary d 10 "sin" x
showsPrec d (Cos x) = showsUnary d 10 "cos" x
showsPrec d (Tan x) = showsUnary d 10 "tan" x
showsPrec d (ASin x) = showsUnary d 10 "asin" x
showsPrec d (ATan x) = showsUnary d 10 "atan" x
showsPrec d (ACos x) = showsUnary d 10 "acos" x
showsPrec d (Sinh x) = showsUnary d 10 "sinh" x
showsPrec d (Cosh x) = showsUnary d 10 "cosh" x
showsPrec d (Tanh x) = showsUnary d 10 "tanh" x
showsPrec d (ASinh x) = showsUnary d 10 "asinh" x
showsPrec d (ATanh x) = showsUnary d 10 "atanh" x
showsPrec d (ACosh x) = showsUnary d 10 "acosh" x
instance Show a => Show (Expr a) where
showsPrec d (ESym s) = showParen (d > 9) $
showString (show s)
showsPrec d (EConst x) = showParen (d >= 9) $ showString (show x)
showsPrec d (ENum x) = showsPrec d x
showsPrec d (EFractional x) = showsPrec d x
showsPrec d (EFloating x) = showsPrec d x
instance Eq a => Eq (Expr a) where
(==) (ESym x) (ESym y) = x == y
(==) (EConst x) (EConst y) = x == y
(==) (ENum x) (ENum y) = x == y
(==) (EFractional x) (EFractional y) = x == y
(==) (EFloating x) (EFloating y) = x == y
(==) _ _ = False
instance Eq a => Eq (Nums a) where
(Mul x0 y0) == (Mul x1 y1) = if commutativeMul
then (x0 == x1 && y0 == y1) || (x0 == y1 && x1 == y0)
else x0 == x1 && y0 == y1
(Add x0 y0) == (Add x1 y1) = if commutativeAdd
then (x0 == x1 && y0 == y1) || (x0 == y1 && x1 == y0)
else x0 == x1 && y0 == y1
(Sub x0 y0) == (Sub x1 y1) = x0 == x1 && y0 == y1
(Negate x) == (Negate y) = x == y
(Abs x) == (Abs y) = x == y
(Signum x) == (Signum y) = x == y
(FromInteger x) == (FromInteger y) = x == y
_ == _ = False
instance Hashable Sym
instance Hashable a => Hashable (Nums a) where
hashWithSalt s (Mul x y) = s `hashWithSalt` "Mul" `hashWithSalt` hx `hashWithSalt` hy
where
hx' = hash x
hy' = hash y
(hx, hy)
| commutativeMul = (min hx' hy', max hx' hy')
| otherwise = (hx', hy')
hashWithSalt s (Add x y) = s `hashWithSalt` "Add" `hashWithSalt` hx `hashWithSalt` hy
where
hx' = hash x
hy' = hash y
(hx, hy)
| commutativeAdd = (min hx' hy', max hx' hy')
| otherwise = (hx', hy')
hashWithSalt s (Sub x y) = s `hashWithSalt` "Sub" `hashWithSalt` x `hashWithSalt` y
hashWithSalt s (Negate x) = s `hashWithSalt` "Negate" `hashWithSalt` x
hashWithSalt s (Abs x) = s `hashWithSalt` "Abs" `hashWithSalt` x
hashWithSalt s (Signum x) = s `hashWithSalt` "Signum" `hashWithSalt` x
hashWithSalt s (FromInteger x) = s `hashWithSalt` "FromInteger" `hashWithSalt` x
instance Hashable a => Hashable (Fractionals a)
instance Hashable a => Hashable (Floatings a)
instance Hashable a => Hashable (Expr a) where
hashWithSalt s (ESym name) = s `hashWithSalt` "ESym" `hashWithSalt` name
hashWithSalt s (EConst x) = s `hashWithSalt` "EConst" `hashWithSalt` x
hashWithSalt s (ENum x) = s `hashWithSalt` "ENum" `hashWithSalt` x
hashWithSalt s (EFractional x) = s `hashWithSalt` "EFractional" `hashWithSalt` x
hashWithSalt s (EFloating x) = s `hashWithSalt` "EFloating" `hashWithSalt` x
fromNeg :: Expr a -> Maybe (Expr a)
fromNeg (ENum (Negate x)) = Just x
fromNeg (ENum (FromInteger k))
| k < 0 = Just (ENum (FromInteger (abs k)))
fromNeg (EFractional (FromRational r))
| r < 0 = Just (EFractional (FromRational (abs r)))
fromNeg _ = Nothing
instance (Eq a, Num a) => Num (Expr a) where
(*) x y
| isVal 0 x || isVal 0 y = 0
| isVal 1 x = y
| isVal 1 y = x
(*) (EConst x) (EConst y) = EConst (x*y)
(*) (ENum (FromInteger kx)) (ENum (FromInteger ky)) = ENum $ FromInteger (kx * ky)
(*) (EFractional (FromRational rx)) (EFractional (FromRational ry)) = EFractional $ FromRational (rx * ry)
(*) (EConst x) (ENum (FromInteger ky)) = EConst $ x * fromInteger ky
(*) (ENum (FromInteger kx)) (EConst y) = EConst $ fromInteger kx * y
(*) (EConst x) (EFractional (FromRational ry)) = EConst $ x * fromRational ry
(*) (EFractional (FromRational rx)) (EConst y) = EConst $ fromRational rx * y
(*) (ENum (FromInteger kx)) (EFractional (FromRational ry)) = EFractional $ FromRational (fromInteger kx * ry)
(*) (EFractional (FromRational rx)) (ENum (FromInteger ky)) = EFractional $ FromRational (rx * fromInteger ky)
(*) x y = case (fromNeg x, fromNeg y) of
(Just x', Just y') -> x' * y'
(Nothing, Just y') -> negate (x * y')
(Just x', Nothing) -> negate (x' * y )
_ -> ENum $ Mul x y
(+) x y
| isVal 0 x = y
| isVal 0 y = x
| x == negate y = 0
(+) (EConst x) (EConst y) = EConst (x+y)
(+) (ENum (FromInteger kx)) (ENum (FromInteger ky)) = ENum $ FromInteger (kx + ky)
(+) (EFractional (FromRational rx)) (EFractional (FromRational ry)) = EFractional $ FromRational (rx + ry)
(+) (EConst x) (ENum (FromInteger ky)) = EConst $ x + fromInteger ky
(+) (ENum (FromInteger kx)) (EConst y) = EConst $ fromInteger kx + y
(+) (EConst x) (EFractional (FromRational ry)) = EConst $ x + fromRational ry
(+) (EFractional (FromRational rx)) (EConst y) = EConst $ fromRational rx + y
(+) (ENum (FromInteger kx)) (EFractional (FromRational ry)) = EFractional $ FromRational (fromInteger kx + ry)
(+) (EFractional (FromRational rx)) (ENum (FromInteger ky)) = EFractional $ FromRational (rx + fromInteger ky)
(+) x y = case (fromNeg x, fromNeg y) of
(Just x', Just y') -> negate (x' + y')
(Nothing, Just y') -> x y'
(Just x', Nothing) -> y x'
_ -> ENum $ Add x y
() x y
| isVal 0 x = negate y
| isVal 0 y = x
| x == y = 0
() (EConst x) (EConst y) = EConst (xy)
() (ENum (FromInteger kx)) (ENum (FromInteger ky)) = ENum $ FromInteger (kx ky)
() (EFractional (FromRational rx)) (EFractional (FromRational ry)) = EFractional $ FromRational (rx ry)
() (EConst x) (ENum (FromInteger ky)) = EConst $ x fromInteger ky
() (ENum (FromInteger kx)) (EConst y) = EConst $ fromInteger kx y
() (EConst x) (EFractional (FromRational ry)) = EConst $ x fromRational ry
() (EFractional (FromRational rx)) (EConst y) = EConst $ fromRational rx y
() (ENum (FromInteger kx)) (EFractional (FromRational ry)) = EFractional $ FromRational (fromInteger kx ry)
() (EFractional (FromRational rx)) (ENum (FromInteger ky)) = EFractional $ FromRational (rx fromInteger ky)
() x y = case (fromNeg x, fromNeg y) of
(Just x', Just y') -> y' x'
(Nothing, Just y') -> x + y'
(Just x', Nothing) -> negate (x' + y)
_ -> ENum $ Sub x y
abs (EConst x) = EConst (abs x)
abs (ENum (FromInteger k)) = ENum (FromInteger (abs k))
abs (EFractional (FromRational r)) = EFractional (FromRational (abs r))
abs x = case fromNeg x of Nothing -> ENum (Abs x)
Just x' -> abs x'
negate (EConst x) = EConst (negate x)
negate (ENum (Sub x y)) = y x
negate (ENum (FromInteger k)) = ENum (FromInteger (negate k))
negate (EFractional (FromRational r)) = EFractional (FromRational (negate r))
negate (ENum (Negate x)) = x
negate x = ENum $ Negate x
signum (EConst x) = EConst (signum x)
signum (ENum (FromInteger k)) = ENum (FromInteger (signum k))
signum (EFractional (FromRational r)) = EFractional (FromRational (signum r))
signum x = ENum $ Signum x
fromInteger = ENum . FromInteger
instance (Eq a, Fractional a) => Fractional (Expr a) where
(/) x y
| isVal 0 y = error "Fractional (Expr a) divide by zero"
| isVal 0 x = 0
| isVal 1 y = x
| x == y = 1
(/) (EConst x) (EConst y) = EConst (x/y)
(/) (ENum (FromInteger kx)) (ENum (FromInteger ky)) = EFractional $ FromRational (kx % ky)
(/) (EFractional (FromRational rx)) (EFractional (FromRational ry)) = EFractional $ FromRational (rx / ry)
(/) (EConst x) (ENum (FromInteger ky)) = EConst $ x / fromInteger ky
(/) (ENum (FromInteger kx)) (EConst y) = EConst $ fromInteger kx / y
(/) (EConst x) (EFractional (FromRational ry)) = EConst $ x / fromRational ry
(/) (EFractional (FromRational rx)) (EConst y) = EConst $ fromRational rx / y
(/) (ENum (FromInteger kx)) (EFractional (FromRational ry)) = EFractional $ FromRational (fromInteger kx / ry)
(/) (EFractional (FromRational rx)) (ENum (FromInteger ky)) = EFractional $ FromRational (rx / fromInteger ky)
(/) x y = case (fromNeg x, fromNeg y) of
(Just x', Just y') -> x' / y'
(Nothing, Just y') -> negate (x / y')
(Just x', Nothing) -> negate (x' / y )
_ -> EFractional $ Div x y
fromRational = EFractional . FromRational
instance (Eq a, Floating a) => Floating (Expr a) where
pi = EConst pi
x ** 1 = x
0 ** 0 = error "Expr: 0 ** 0 indeterminate"
_ ** 0 = 1
x ** y = EFloating $ Pow x y
logBase x y = EFloating $ LogBase x y
exp = applyFloatingUn ( exp, Exp)
log = applyFloatingUn ( log, Log)
sin = applyFloatingUn ( sin, Sin)
cos = applyFloatingUn ( cos, Cos)
tan = applyFloatingUn ( tan, Tan)
asin = applyFloatingUn ( asin, ASin)
atan = applyFloatingUn ( atan, ATan)
acos = applyFloatingUn ( acos, ACos)
sinh = applyFloatingUn ( sinh, Sinh)
cosh = applyFloatingUn ( cosh, Cosh)
tanh = applyFloatingUn ( tanh, Tanh)
asinh = applyFloatingUn (asinh, ASinh)
atanh = applyFloatingUn (atanh, ATanh)
acosh = applyFloatingUn (acosh, ACosh)
applyFloatingUn :: Floating a => (t -> a, Expr t -> Floatings (Expr a)) -> Expr t -> Expr a
applyFloatingUn (f,_) (EConst x) = EConst (f x)
applyFloatingUn (f,_) (ENum (FromInteger x)) = EConst (f $ fromInteger x)
applyFloatingUn (f,_) (EFractional (FromRational x)) = EConst (f $ fromRational x)
applyFloatingUn (_,f) x = EFloating (f x)
data GExpr a b where
GSym :: Sym -> GExpr a b
GConst :: a -> GExpr a b
GNum :: Num a => Nums b -> GExpr a b
GFractional :: Fractional a => Fractionals b -> GExpr a b
GFloating :: Floating a => Floatings b -> GExpr a b
deriving instance (Show a, Show b) => Show (GExpr a b)
instance Functor (GExpr a) where
fmap _ (GSym s) = GSym s
fmap _ (GConst c) = GConst c
fmap f (GNum nums) = GNum (fmap f nums)
fmap f (GFractional fracs) = GFractional (fmap f fracs)
fmap f (GFloating floatings) = GFloating (fmap f floatings)
instance F.Foldable (GExpr a) where
foldMap _ (GSym _) = mempty
foldMap _ (GConst _) = mempty
foldMap f (GNum nums) = F.foldMap f nums
foldMap f (GFractional fracs) = F.foldMap f fracs
foldMap f (GFloating floatings) = F.foldMap f floatings
foldr _ z (GSym _) = z
foldr _ z (GConst _) = z
foldr f z (GNum nums) = F.foldr f z nums
foldr f z (GFractional fracs) = F.foldr f z fracs
foldr f z (GFloating floatings) = F.foldr f z floatings
instance T.Traversable (GExpr a) where
traverse _ (GSym s) = pure (GSym s)
traverse _ (GConst c) = pure (GConst c)
traverse f (GNum nums) = GNum <$> T.traverse f nums
traverse f (GFractional fracs) = GFractional <$> T.traverse f fracs
traverse f (GFloating floatings) = GFloating <$> T.traverse f floatings
deriving instance (Eq a, Eq b) => Eq (GExpr a b)
instance (Hashable a, Hashable b) => Hashable (GExpr a b) where
hashWithSalt s (GSym name) = s `hashWithSalt` "GSym" `hashWithSalt` name
hashWithSalt s (GConst x) = s `hashWithSalt` "GConst" `hashWithSalt` x
hashWithSalt s (GNum x) = s `hashWithSalt` "GNum" `hashWithSalt` x
hashWithSalt s (GFractional x) = s `hashWithSalt` "GFractional" `hashWithSalt` x
hashWithSalt s (GFloating x) = s `hashWithSalt` "GFloating" `hashWithSalt` x
substitute :: (Ord a, Hashable a, Show a) => Expr a -> [(Expr a, Expr a)] -> Expr a
substitute expr subList
| nonSymInputs /= [] = error $ "substitute got non-ESym input: " ++ show nonSymInputs
| otherwise = subs expr
where
isSym (ESym _) = True
isSym _ = False
nonSymInputs = filter (not . isSym . fst) subList
lookup' e = let hm = HM.fromList subList in
HM.lookupDefault e e hm
subs e@(ESym _) = lookup' e
subs e@(EConst _) = e
subs e@(ENum (FromInteger _)) = e
subs e@(EFractional (FromRational _)) = e
subs (ENum (Mul x y)) = subs x * subs y
subs (ENum (Add x y)) = subs x + subs y
subs (ENum (Sub x y)) = subs x subs y
subs (ENum (Negate x)) = negate (subs x)
subs (ENum (Abs x)) = abs (subs x)
subs (ENum (Signum x)) = signum (subs x)
subs (EFractional (Div x y)) = subs x / subs y
subs (EFloating (Pow x y)) = subs x ** subs y
subs (EFloating (LogBase x y)) = logBase (subs x) (subs y)
subs (EFloating (Exp x)) = exp (subs x)
subs (EFloating (Log x)) = log (subs x)
subs (EFloating (Sin x)) = sin (subs x)
subs (EFloating (Cos x)) = cos (subs x)
subs (EFloating (Tan x)) = tan (subs x)
subs (EFloating (ASin x)) = asin (subs x)
subs (EFloating (ATan x)) = atan (subs x)
subs (EFloating (ACos x)) = acos (subs x)
subs (EFloating (Sinh x)) = sinh (subs x)
subs (EFloating (Cosh x)) = cosh (subs x)
subs (EFloating (Tanh x)) = tanh (subs x)
subs (EFloating (ASinh x)) = asinh (subs x)
subs (EFloating (ATanh x)) = atanh (subs x)
subs (EFloating (ACosh x)) = acosh (subs x)
sketchySubstitute :: (Eq a, Hashable a, Show a) => Expr a -> [(Expr a, Expr a)] -> Expr a
sketchySubstitute expr subList
| nonSymInputs /= [] = error $ "substitute got non-ESym input: " ++ show nonSymInputs
| otherwise = subs expr
where
isSym (ESym _) = True
isSym _ = False
nonSymInputs = filter (not . isSym . fst) subList
lookup' e = let hm = HM.fromList subList in
HM.lookupDefault e e hm
subs e@(ESym _) = lookup' e
subs e@(EConst _) = e
subs e@(ENum (FromInteger _)) = e
subs e@(EFractional (FromRational _)) = e
subs (ENum (Mul x y)) = ENum (Mul (subs x) (subs y))
subs (ENum (Add x y)) = ENum (Add (subs x) (subs y))
subs (ENum (Sub x y)) = ENum (Sub (subs x) (subs y))
subs (ENum (Negate x)) = ENum (Negate (subs x))
subs (ENum (Abs x)) = ENum (Negate (subs x))
subs (ENum (Signum x)) = ENum (Signum (subs x))
subs (EFractional (Div x y)) = EFractional (Div (subs x) (subs y))
subs (EFloating (Pow x y)) = EFloating (Pow (subs x) (subs y))
subs (EFloating (LogBase x y)) = EFloating (LogBase (subs x) (subs y))
subs (EFloating (Exp x)) = EFloating (Exp (subs x))
subs (EFloating (Log x)) = EFloating (Log (subs x))
subs (EFloating (Sin x)) = EFloating (Sin (subs x))
subs (EFloating (Cos x)) = EFloating (Cos (subs x))
subs (EFloating (Tan x)) = EFloating (Tan (subs x))
subs (EFloating (ASin x)) = EFloating (ASin (subs x))
subs (EFloating (ATan x)) = EFloating (ATan (subs x))
subs (EFloating (ACos x)) = EFloating (ACos (subs x))
subs (EFloating (Sinh x)) = EFloating (Sinh (subs x))
subs (EFloating (Cosh x)) = EFloating (Cosh (subs x))
subs (EFloating (Tanh x)) = EFloating (Tanh (subs x))
subs (EFloating (ASinh x)) = EFloating (ASinh (subs x))
subs (EFloating (ATanh x)) = EFloating (ATanh (subs x))
subs (EFloating (ACosh x)) = EFloating (ACosh (subs x))
foldExpr :: (Expr a -> b -> b) -> b -> Expr a -> b
foldExpr f acc e@(ESym _) = f e acc
foldExpr f acc e@(EConst _) = f e acc
foldExpr f acc e@(EFractional (FromRational _)) = f e acc
foldExpr f acc e@(ENum (FromInteger _)) = f e acc
foldExpr f acc (ENum (Mul x y)) = foldExpr f (foldExpr f acc y) x
foldExpr f acc (ENum (Add x y)) = foldExpr f (foldExpr f acc y) x
foldExpr f acc (ENum (Sub x y)) = foldExpr f (foldExpr f acc y) x
foldExpr f acc (ENum (Negate x)) = foldExpr f acc x
foldExpr f acc (ENum (Abs x)) = foldExpr f acc x
foldExpr f acc (ENum (Signum x)) = foldExpr f acc x
foldExpr f acc (EFractional (Div x y)) = foldExpr f (foldExpr f acc y) x
foldExpr f acc (EFloating (Pow x y)) = foldExpr f (foldExpr f acc y) x
foldExpr f acc (EFloating (LogBase x y)) = foldExpr f (foldExpr f acc y) x
foldExpr f acc (EFloating (Exp x)) = foldExpr f acc x
foldExpr f acc (EFloating (Log x)) = foldExpr f acc x
foldExpr f acc (EFloating (Sin x)) = foldExpr f acc x
foldExpr f acc (EFloating (Cos x)) = foldExpr f acc x
foldExpr f acc (EFloating (Tan x)) = foldExpr f acc x
foldExpr f acc (EFloating (ASin x)) = foldExpr f acc x
foldExpr f acc (EFloating (ATan x)) = foldExpr f acc x
foldExpr f acc (EFloating (ACos x)) = foldExpr f acc x
foldExpr f acc (EFloating (Sinh x)) = foldExpr f acc x
foldExpr f acc (EFloating (Cosh x)) = foldExpr f acc x
foldExpr f acc (EFloating (Tanh x)) = foldExpr f acc x
foldExpr f acc (EFloating (ASinh x)) = foldExpr f acc x
foldExpr f acc (EFloating (ATanh x)) = foldExpr f acc x
foldExpr f acc (EFloating (ACosh x)) = foldExpr f acc x
sym :: String -> Expr a
sym = ESym . Sym
symDependent :: String -> Expr a -> Expr a
symDependent name s = symDependentN name s 0
symDependentN :: String -> Expr a -> Int -> Expr a
symDependentN name (ESym s) n = ESym (SymDependent name n s)
symDependentN _ _ _ = error "symDependent got non ESym dependency"
const' :: a -> Expr a
const' = EConst
isVal :: Eq a => a -> Expr a -> Bool
isVal v (EConst c) = v == c
isVal v (ENum (FromInteger k)) = v == fromInteger k
isVal v (EFractional (FromRational r)) = v == fromRational r
isVal _ _ = False
getConst :: Expr a -> Maybe a
getConst (EConst x) = Just x
getConst (ENum (FromInteger k)) = Just (fromInteger k)
getConst (EFractional (FromRational r)) = Just (fromRational r)
getConst _ = Nothing
extractLinearPart :: (Num a, Eq a, Show a) => Expr a -> Expr a -> (Expr a, a)
extractLinearPart e@(EConst _) _ = (e,0)
extractLinearPart e@(ENum (FromInteger _)) _ = (e,0)
extractLinearPart e@(EFractional (FromRational _)) _ = (e,0)
extractLinearPart e@(ESym _) arg
| e == arg = (0,1)
| otherwise = (e,0)
extractLinearPart (ENum (Add x y)) arg = (xNonlin+yNonlin, xLin+yLin)
where
(xNonlin,xLin) = extractLinearPart x arg
(yNonlin,yLin) = extractLinearPart y arg
extractLinearPart (ENum (Sub x y)) arg = (xNonlinyNonlin, xLinyLin)
where
(xNonlin,xLin) = extractLinearPart x arg
(yNonlin,yLin) = extractLinearPart y arg
extractLinearPart (ENum (Negate x)) arg = (xNonlin, xLin)
where
(xNonlin,xLin) = extractLinearPart x arg
extractLinearPart e@(ENum (Mul x y)) arg = case (getConst x, getConst y) of
(Nothing,Nothing) -> (e,0)
(Just cx, Nothing) -> let (yNl,yL) = extractLinearPart y arg in (EConst cx * yNl,cx*yL)
(Nothing, Just cy) -> let (xNl,xL) = extractLinearPart x arg in (xNl * EConst cy,xL*cy)
_ -> error $ "extractLinearPart got ENum (Mul x y) where x and y are both constants\n"++
"x: " ++ show x ++ "\ny: " ++ show y
extractLinearPart e@(EFractional (Div x y)) arg = case getConst y of
Nothing -> (e,0)
Just cy -> let (xNl,xL) = extractLinearPart x arg in (xNl/EConst cy,xL/cy)
extractLinearPart e@(ENum (Abs _)) _ = (e,0)
extractLinearPart e@(ENum (Signum _)) _ = (e,0)
extractLinearPart e@(EFloating (Pow _ _)) _ = (e,0)
extractLinearPart e@(EFloating (LogBase _ _)) _ = (e,0)
extractLinearPart e@(EFloating (Exp _)) _ = (e,0)
extractLinearPart e@(EFloating (Log _)) _ = (e,0)
extractLinearPart e@(EFloating (Sin _)) _ = (e,0)
extractLinearPart e@(EFloating (Cos _)) _ = (e,0)
extractLinearPart e@(EFloating (Tan _)) _ = (e,0)
extractLinearPart e@(EFloating (ASin _)) _ = (e,0)
extractLinearPart e@(EFloating (ATan _)) _ = (e,0)
extractLinearPart e@(EFloating (ACos _)) _ = (e,0)
extractLinearPart e@(EFloating (Sinh _)) _ = (e,0)
extractLinearPart e@(EFloating (Cosh _)) _ = (e,0)
extractLinearPart e@(EFloating (Tanh _)) _ = (e,0)
extractLinearPart e@(EFloating (ASinh _)) _ = (e,0)
extractLinearPart e@(EFloating (ATanh _)) _ = (e,0)
extractLinearPart e@(EFloating (ACosh _)) _ = (e,0)