{-# language FlexibleInstances, DeriveFunctor #-}
{-# language ScopedTypeVariables #-}
{-# language RankNTypes #-}
{-# language OverloadedStrings #-}
{-# language LambdaCase #-}
-----------------------------------------------------------------------------
-- |
-- Module      :  Data.SRTree.Internal 
-- Copyright   :  (c) Fabricio Olivetti 2021 - 2024
-- License     :  BSD3
-- Maintainer  :  fabricio.olivetti@gmail.com
-- Stability   :  experimental
-- Portability :  FlexibleInstances, DeriveFunctor, ScopedTypeVariables, RankNTypes, OverloadedStrings, LambdaCase
--
-- Expression tree for Symbolic Regression
--
-----------------------------------------------------------------------------

module Data.SRTree.Internal
         ( SRTree(..)
         , Function(..)
         , Op(..)
         , param
         , var
         , constv
         , arity
         , getChildren
         , childrenOf
         , replaceChildren
         , getOperator
         , countNodes
         , countVarNodes
         , countConsts
         , countParams
         , countOccurrences
         , countUniqueTokens
         , numberOfVars
         , getIntConsts
         , relabelParams
         , relabelVars
         , constsToParam
         , floatConstsToParam
         , paramsToConst
         , Fix (..)
         )
         where

import Control.Monad.State (MonadState (get), State, evalState, modify)
import Data.SRTree.Recursion (Fix (..), cata, cataM)
import qualified Data.Set as S
import Data.String (IsString (..))
import Text.Read (readMaybe)

-- | Tree structure to be used with Symbolic Regression algorithms.
-- This structure is a fixed point of an n-ary tree: the type parameter
-- @val@ is the type stored in the child positions, and @'Fix' 'SRTree'@
-- ties the recursive knot.
data SRTree val =
   Var Int          -- ^ index of the variables
 | Param Int        -- ^ index of the parameter
 | Const Double     -- ^ constant value, can be converted to a parameter
 {- TODO: planned constructors (kept in a block comment so that Haddock
    does not attach them to a neighbouring constructor):
      IConst Int    -- integer constant
      RConst Ratio  -- rational constant
 -}
 | Uni Function val -- ^ univariate function
 | Bin Op val val   -- ^ binary operator
 deriving (Show, Eq, Ord, Functor)

-- | Supported binary operators.
--
-- The derived 'Enum' instance numbers the constructors in declaration
-- order, starting with 'Add' at 0 and ending with 'AQ' at 6.
data Op = Add | Sub | Mul | Div | Power | PowerAbs | AQ
    deriving (Show, Read, Eq, Ord, Enum)

-- | Supported univariate functions.
--
-- The derived 'Enum' instance numbers the constructors in declaration
-- order, from 'Id' at 0 to 'Cube' at 22.
data Function =
    Id
  | Abs
  | Sin
  | Cos
  | Tan
  | Sinh
  | Cosh
  | Tanh
  | ASin
  | ACos
  | ATan
  | ASinh
  | ACosh
  | ATanh
  | Sqrt
  | SqrtAbs
  | Cbrt
  | Square
  | Log
  | LogAbs
  | Exp
  | Recip
  | Cube
     deriving (Show, Read, Eq, Ord, Enum)

-- | Create a tree with a single node representing a variable.
--
-- The argument is the variable index; @var 0@ builds the same node as the
-- string literal @"x0"@ through the 'IsString' instance.
var :: Int -> Fix SRTree
var ix = Fix (Var ix)

-- | Create a tree with a single node representing a parameter.
--
-- The argument is the parameter index; @param 0@ builds the same node as
-- the string literal @"t0"@ through the 'IsString' instance.
param :: Int -> Fix SRTree
param ix = Fix (Param ix)

-- | Create a tree with a single node representing a constant value.
constv :: Double -> Fix SRTree
constv x = Fix (Const x)

-- | The instance of `IsString` allows us to
-- create a tree using a more practical notation:
--
-- >>> :t  "x0" + "t0" * sin("x1" * "t1")
-- Fix SRTree
--
-- A string @"xi"@ becomes the variable with index @i@ and @"ti"@ the
-- parameter with index @i@, where @i@ is a non-empty sequence of decimal
-- digits. Any other string is rejected with 'error'.
instance IsString (Fix SRTree) where
    fromString str = case str of
        []       -> error "empty string for SRTree"
        'x' : ix -> case parseIndex ix of
                      Just iy -> Fix (Var iy)
                      Nothing -> error "wrong format for variable. It should be xi where i is an index. Ex.: \"x0\", \"x1\"."
        't' : ix -> case parseIndex ix of
                      Just iy -> Fix (Param iy)
                      Nothing -> error "wrong format for parameter. It should be ti where i is an index. Ex.: \"t0\", \"t1\"."
        _        -> error "A string can represent a variable or a parameter following the format xi or ti, respectivelly, where i is the index. Ex.: \"x0\", \"t0\"."
      where
        -- Accept only a non-empty run of ASCII digits whose value fits in
        -- an Int. 'readMaybe' on its own would also accept "-1" (giving a
        -- negative index), surrounding whitespace, or "(1)". Reading an
        -- oversized literal directly at type Int would silently wrap around.
        parseIndex :: String -> Maybe Int
        parseIndex ix
          | null ix || not (all isAsciiDigit ix) = Nothing
          | otherwise = case (readMaybe ix :: Maybe Integer) of
              Just n | n <= toInteger (maxBound :: Int) -> Just (fromInteger n)
              _ -> Nothing

        isAsciiDigit :: Char -> Bool
        isAsciiDigit c = c >= '0' && c <= '9'

-- | Arithmetic on expression trees.
--
-- The operators do light constant folding while building the tree:
--
-- * @0 + r = r@, @l + 0 = l@, @l - 0 = l@ and @0 - r = negate r@;
-- * @0 * _ = 0@, @_ * 0 = 0@, @1 * r = r@ and @l * 1 = l@;
-- * two 'Const' operands are combined into a single 'Const'.
--
-- Any other combination becomes a 'Bin' node.
instance Num (Fix SRTree) where
  Fix (Const 0) + r               = r
  l + Fix (Const 0)               = l
  Fix (Const c1) + Fix (Const c2) = Fix . Const $ c1 + c2
  l + r                           = Fix $ Bin Add l r
  {-# INLINE (+) #-}

  l - Fix (Const 0)               = l
  Fix (Const 0) - r               = negate r
  Fix (Const c1) - Fix (Const c2) = Fix . Const $ c1 - c2
  l - r                           = Fix $ Bin Sub l r
  {-# INLINE (-) #-}

  -- A zero operand absorbs the whole product, even a non-constant one.
  Fix (Const 0) * _               = Fix (Const 0)
  _ * Fix (Const 0)               = Fix (Const 0)
  Fix (Const 1) * r               = r
  l * Fix (Const 1)               = l
  Fix (Const c1) * Fix (Const c2) = Fix . Const $ c1 * c2
  l * r                           = Fix $ Bin Mul l r
  {-# INLINE (*) #-}

  abs = Fix . Uni Abs
  {-# INLINE abs #-}

  -- There is no dedicated negation node: a constant is negated in place,
  -- anything else is multiplied by the constant -1.
  negate (Fix (Const x)) = Fix $ Const (negate x)
  negate t               = Fix (Const (-1)) * t
  {-# INLINE negate #-}

  -- NOTE(review): for a non-constant tree this yields the constant 0, not
  -- the sign of the expression ('Function' has no sign constructor to build
  -- one). Confirm callers never rely on 'signum' of a non-constant tree.
  signum t = case t of
               Fix (Const x) -> Fix . Const $ signum x
               _             -> Fix (Const 0)

  fromInteger x = Fix $ Const (fromInteger x)
  {-# INLINE fromInteger #-}

-- | Division on expression trees.
--
-- Dividing by the constant 1 returns the numerator unchanged, and two
-- 'Const' operands are folded into one 'Const'. Any other division becomes a
-- 'Bin' 'Div' node. 'recip' always builds a 'Recip' node, even for constants.
instance Fractional (Fix SRTree) where
  l / Fix (Const 1)               = l
  Fix (Const c1) / Fix (Const c2) = Fix . Const $ c1 / c2
  l / r                           = Fix $ Bin Div l r
  {-# INLINE (/) #-}

  recip = Fix . Uni Recip
  {-# INLINE recip #-}

  -- The rational literal is converted to a Double and stored as a 'Const'.
  fromRational = Fix . Const . fromRational
  {-# INLINE fromRational #-}

-- | Transcendental functions on expression trees.
--
-- Each unary method wraps its argument in the matching 'Uni' node without
-- simplification. '(**)' drops an exponent of 1 and maps an exponent of 0
-- to the constant 1; otherwise it builds a 'Bin' 'Power' node.
instance Floating (Fix SRTree) where
  pi    = Fix $ Const pi
  {-# INLINE pi #-}
  exp   = Fix . Uni Exp
  {-# INLINE exp #-}
  log   = Fix . Uni Log
  {-# INLINE log #-}
  sqrt  = Fix . Uni Sqrt
  {-# INLINE sqrt #-}
  sin   = Fix . Uni Sin
  {-# INLINE sin #-}
  cos   = Fix . Uni Cos
  {-# INLINE cos #-}
  tan   = Fix . Uni Tan
  {-# INLINE tan #-}
  asin  = Fix . Uni ASin
  {-# INLINE asin #-}
  acos  = Fix . Uni ACos
  {-# INLINE acos #-}
  atan  = Fix . Uni ATan
  {-# INLINE atan #-}
  sinh  = Fix . Uni Sinh
  {-# INLINE sinh #-}
  cosh  = Fix . Uni Cosh
  {-# INLINE cosh #-}
  tanh  = Fix . Uni Tanh
  {-# INLINE tanh #-}
  asinh = Fix . Uni ASinh
  {-# INLINE asinh #-}
  acosh = Fix . Uni ACosh
  {-# INLINE acosh #-}
  atanh = Fix . Uni ATanh
  {-# INLINE atanh #-}

  l ** Fix (Const 1) = l
  _ ** Fix (Const 0) = Fix (Const 1)
  l ** r             = Fix $ Bin Power l r
  {-# INLINE (**) #-}

  -- @logBase b x@ is the logarithm of @x@ in base @b@, with the same
  -- argument order as the Prelude: @log x / log b@. The logarithm of 1 is
  -- 0 in every base, which the first equation folds directly.
  logBase _ (Fix (Const 1)) = Fix (Const 0)
  logBase b x               = log x / log b
  {-# INLINE logBase #-}

-- | Folds over the direct children of a single node: both operands of a
-- 'Bin' (left then right), the argument of a 'Uni', and nothing for the
-- leaves 'Var', 'Param' and 'Const'.
instance Foldable SRTree where
    foldMap f = \case
        Bin _ l r -> f l <> f r
        Uni _ t   -> f t
        _         -> mempty

-- | Traverses the direct children of a single node, left operand before
-- right operand for 'Bin'. Leaves are rebuilt unchanged with 'pure'.
--
-- 'sequence' is not overridden: the class default
-- (@sequence = sequenceA = traverse id@) produces exactly the result the
-- previous hand-written version did.
instance Traversable SRTree where
    traverse f = \case
        Bin op l r -> Bin op <$> f l <*> f r
        Uni g t    -> Uni g <$> f t
        Var ix     -> pure (Var ix)
        Param ix   -> pure (Param ix)
        Const x    -> pure (Const x)

-- | Arity of the root node of a tree: 0 for leaves ('Var', 'Param',
-- 'Const'), 1 for 'Uni' and 2 for 'Bin'. Only the root constructor is
-- inspected; the subtrees are never visited.
arity :: Fix SRTree -> Int
arity (Fix node) = case node of
    Var {}   -> 0
    Param {} -> 0
    Const {} -> 0
    Uni {}   -> 1
    Bin {}   -> 2
{-# INLINE arity #-}

-- | Get the direct subtrees of the root node, left to right. Leaf nodes
-- ('Var', 'Param', 'Const') yield an empty list, a 'Uni' node yields its
-- single argument and a 'Bin' node yields its left and right operands.
-- For example, the children of @"x0" + 2@ are the subtrees @"x0"@ and @2@.
getChildren :: Fix SRTree -> [Fix SRTree]
getChildren (Fix node) = case node of
    Uni _ t   -> [t]
    Bin _ l r -> [l, r]
    _         -> []
{-# INLINE getChildren #-}

-- | Get the children of a single (unfixed) 'SRTree' layer, left to right.
-- Leaves have no children.
childrenOf :: SRTree a -> [a]
childrenOf (Uni _ t)   = [t]
childrenOf (Bin _ l r) = [l, r]
childrenOf _           = []

-- | Replace the children of a node with the elements of a list, keeping the
-- operator, function or leaf payload unchanged.
--
-- * A 'Bin' node needs exactly two elements (left, then right).
-- * A 'Uni' node needs exactly one element.
-- * Leaves ('Var', 'Param', 'Const') ignore the list entirely.
--
-- Calls 'error' when an internal node receives a list of the wrong length,
-- whether it is too short or too long.
replaceChildren :: [a] -> SRTree b -> SRTree a
replaceChildren [l, r] (Bin op _ _) = Bin op l r
replaceChildren [t]    (Uni f _)    = Uni f t
replaceChildren _      (Var ix)     = Var ix
replaceChildren _      (Param ix)   = Param ix
replaceChildren _      (Const x)    = Const x
replaceChildren _      node         =
    error $ "ERROR: trying to replace children with the wrong number of elements (expected exactly "
         <> show (expectedChildren node) <> ")."
  where
    -- Only 'Bin' and 'Uni' nodes can reach this equation; the length of the
    -- input list is deliberately not computed, so an infinite list cannot hang
    -- the error path.
    expectedChildren :: SRTree c -> Int
    expectedChildren Bin {} = 2
    expectedChildren _      = 1
{-# INLINE replaceChildren #-}

-- | Keep only the "shape" of a node: its operator, function or leaf payload,
-- with every child replaced by @()@.
getOperator :: SRTree a -> SRTree ()
getOperator = (() <$)
{-# INLINE getOperator #-}

-- | Count the number of nodes in a tree; leaves and internal nodes each
-- count as one.
--
-- >>> countNodes $ "x0" + 2
-- 3
countNodes :: Num a => Fix SRTree -> a
countNodes = cata $ \case
    Uni _ t   -> 1 + t
    Bin _ l r -> 1 + l + r
    _         -> 1
{-# INLINE countNodes #-}

-- | Count the number of `Var` nodes; repeated occurrences of the same
-- variable are each counted.
--
-- >>> countVarNodes $ "x0" + 2 * ("x0" - sin "x1")
-- 3
countVarNodes :: Num a => Fix SRTree -> a
countVarNodes = cata $ \case
    Var {}    -> 1
    Uni _ t   -> t
    Bin _ l r -> l + r
    _         -> 0
{-# INLINE countVarNodes #-}

-- | Count the number of `Param` nodes; repeated occurrences of the same
-- parameter index are each counted.
--
-- >>> countParams $ "x0" + "t0" * sin ("t1" + "x1") - "t0"
-- 3
countParams :: Num a => Fix SRTree -> a
countParams = cata $ \case
    Param {}  -> 1
    Uni _ t   -> t
    Bin _ l r -> l + r
    _         -> 0
{-# INLINE countParams #-}

-- | Count the number of `Const` nodes in a tree.
--
-- >>> countConsts $ "x0"* 2 + 3 * sin "x0"
-- 2
countConsts :: Num a => Fix SRTree -> a
countConsts = cata $ \case
    Const {}  -> 1
    Uni _ t   -> t
    Bin _ l r -> l + r
    _         -> 0
{-# INLINE countConsts #-}

-- | Count how many `Var` nodes refer to the variable with index `ix`.
--
-- >>> countOccurrences 0 $ "x0"* 2 + 3 * sin "x0" + "x1"
-- 2
countOccurrences :: Num a => Int -> Fix SRTree -> a
countOccurrences ix = cata $ \case
    Var iy | iy == ix -> 1
    Uni _ t           -> t
    Bin _ l r         -> l + r
    _                 -> 0
{-# INLINE countOccurrences #-}

-- | Count the number of unique tokens in a tree. The count is the sum of:
--
-- * the number of distinct binary operators,
-- * the number of distinct functions,
-- * the number of distinct variable indices,
-- * one if the tree contains any parameter (all parameters are one token),
-- * one if the tree contains any constant (all constants are one token).
--
-- >>> countUniqueTokens $ "x0" + ("x1" * "x0" - sin ("x0" ** 2))
-- 8
countUniqueTokens :: Num a => Fix SRTree -> a
countUniqueTokens tree =
    fromIntegral (S.size ops + S.size funs + S.size vars + S.size params + S.size consts)
  where
    (ops, funs, vars, params, consts) = cata collect tree

    -- The 5-tuple Semigroup instance unions the sets component-wise.
    collect :: SRTree (S.Set Op, S.Set Function, S.Set Int, S.Set (), S.Set ())
            -> (S.Set Op, S.Set Function, S.Set Int, S.Set (), S.Set ())
    collect = \case
        Var ix     -> (S.empty, S.empty, S.singleton ix, S.empty, S.empty)
        Param _    -> (S.empty, S.empty, S.empty, S.singleton (), S.empty)
        Const _    -> (S.empty, S.empty, S.empty, S.empty, S.singleton ())
        Uni f t    -> (S.empty, S.singleton f, S.empty, S.empty, S.empty) <> t
        Bin op l r -> (S.singleton op, S.empty, S.empty, S.empty, S.empty) <> l <> r
{-# INLINE countUniqueTokens #-}

-- | Return the number of distinct variable indices used in the tree.
--
-- >>> numberOfVars $ "x0" + 2 * ("x0" - sin "x1")
-- 2
numberOfVars :: Num a => Fix SRTree -> a
numberOfVars tree = fromIntegral (S.size (cata collectVars tree))
  where
    collectVars :: SRTree (S.Set Int) -> S.Set Int
    collectVars = \case
        Var ix    -> S.singleton ix
        Uni _ t   -> t
        Bin _ l r -> S.union l r
        _         -> S.empty
{-# INLINE numberOfVars #-}

-- | Return the integer-valued constants of the tree, in left-to-right order.
-- A constant @x@ counts as an integer when @floor x == ceiling x@.
--
-- >>> getIntConsts $ "x0" + 2 * "x1" ** 3 - 3.14
-- [2.0,3.0]
getIntConsts :: Fix SRTree -> [Double]
getIntConsts = cata $ \case
    Const x | isIntegral x -> [x]
    Uni _ t                -> t
    Bin _ l r              -> l ++ r
    _                      -> []
  where
    isIntegral :: Double -> Bool
    isIntegral x = (floor x :: Integer) == ceiling x
{-# INLINE getIntConsts #-}

-- | Relabel the parameter indices incrementally, starting from 0, in the
-- left-to-right order in which 'Param' nodes appear. Every occurrence gets a
-- fresh index, even when two occurrences originally shared one.
-- Variables and constants are left untouched.
--
-- >>> showExpr . relabelParams $ "x0" + "t0" * sin ("t1" + "x1") - "t0"
-- "x0" + "t0" * sin ("t1" + "x1") - "t2"
relabelParams :: Fix SRTree -> Fix SRTree
relabelParams t = evalState (cataM sequenceA renumber t) 0
  where
    -- 'sequenceA' comes from this module's 'Traversable' instance. It runs
    -- a left child's effects before the right child's, which fixes the
    -- numbering order.

    -- Hand out the current counter to each 'Param' and bump it; all other
    -- nodes are rebuilt as they are.
    renumber :: SRTree (Fix SRTree) -> State Int (Fix SRTree)
    renumber (Param _) = param <$> (get <* modify (+1))
    renumber node      = pure (Fix node)

-- | Relabel the variable indices incrementally, starting from 0, in the
-- left-to-right order in which 'Var' nodes appear. Every occurrence gets a
-- fresh index, so repeated occurrences of the same variable become distinct
-- variables: @x0 + x1 * x0@ becomes @x0 + x1 * x2@.
-- Parameters and constants are left untouched.
relabelVars :: Fix SRTree -> Fix SRTree
relabelVars t = evalState (cataM sequenceA renumber t) 0
  where
    -- 'sequenceA' comes from this module's 'Traversable' instance. It runs
    -- a left child's effects before the right child's, which fixes the
    -- numbering order.

    -- Any time we reach a 'Var', replace its index with the current counter
    -- and increment the counter; all other nodes are rebuilt as they are.
    renumber :: SRTree (Fix SRTree) -> State Int (Fix SRTree)
    renumber (Var _) = var <$> (get <* modify (+1))
    renumber node    = pure (Fix node)

-- | Change every constant into a parameter. Returns the changed tree
-- together with the parameter values in left-to-right order.
--
-- * Parameters already present in the tree are kept and contribute a
--   placeholder value of 1.0 to the list.
-- * All parameters are then renumbered with 'relabelParams', so indices
--   follow the same left-to-right order as the returned values.
--
-- >>> snd . constsToParam $ "x0" * 2 + 3.14 * sin (5 * "x1")
-- [2.0,3.14,5.0]
constsToParam :: Fix SRTree -> (Fix SRTree, [Double])
constsToParam tree = (relabelParams tree', values)
  where
    (tree', values) = cata step tree

    step :: SRTree (Fix SRTree, [Double]) -> (Fix SRTree, [Double])
    step = \case
        Var ix                 -> (Fix (Var ix), [])
        Param ix               -> (Fix (Param ix), [1.0])
        -- index 0 is a placeholder; relabelParams assigns the real one
        Const c                -> (Fix (Param 0), [c])
        Uni f (t, ts)          -> (Fix (Uni f t), ts)
        Bin op (l, ls) (r, rs) -> (Fix (Bin op l r), ls <> rs)

-- | Same as `constsToParam`, but constants whose value is integral
-- (@floor c == ceiling c@) are kept as constants. Only the remaining
-- constants become parameters and are reported in the value list.
--
-- >>> snd . floatConstsToParam $ "x0" * 2 + 3.14 * sin (5 * "x1")
-- [3.14]
floatConstsToParam :: Fix SRTree -> (Fix SRTree, [Double])
floatConstsToParam tree = (relabelParams tree', values)
  where
    (tree', values) = cata step tree

    isIntegral :: Double -> Bool
    isIntegral c = (floor c :: Integer) == ceiling c

    step :: SRTree (Fix SRTree, [Double]) -> (Fix SRTree, [Double])
    step = \case
        Var ix                 -> (var ix, [])
        Param ix               -> (param ix, [1.0])
        Const c
          | isIntegral c       -> (constv c, [])
          -- index 0 is a placeholder; relabelParams assigns the real one
          | otherwise          -> (param 0, [c])
        Uni f (t, ts)          -> (Fix (Uni f t), ts)
        Bin op (l, ls) (r, rs) -> (Fix (Bin op l r), ls <> rs)

-- | Convert the parameters into constants in the tree.
--
-- Each @Param ix@ node is replaced by @Const (theta at index ix)@; every
-- other node is rebuilt unchanged. An index that is negative or not smaller
-- than @length theta@ raises an 'error' naming the offending index. As with
-- the list lookup this replaces, the error fires only when that constant's
-- value is actually demanded.
--
-- >>> showExpr . paramsToConst [1.1, 2.2, 3.3] $ "x0" + "t0" * sin ("t1" * "x0" - "t2")
-- x0 + 1.1 * sin(2.2 * x0 - 3.3)
paramsToConst :: [Double] -> Fix SRTree -> Fix SRTree
paramsToConst theta = cata alg
  where
      alg (Param ix) = Fix $ Const (lookupParam ix)
      -- Var, Const, Uni and Bin nodes are passed through as they are
      alg node       = Fix node

      -- total replacement for (theta !! ix) with an informative failure
      lookupParam :: Int -> Double
      lookupParam ix
        | ix >= 0, (v:_) <- drop ix theta = v
        | otherwise = error $ "paramsToConst: parameter index " ++ show ix
                           ++ " out of range for " ++ show (length theta)
                           ++ " parameter value(s)"