{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE LambdaCase #-}
-----------------------------------------------------------------------------
-- |
-- Module      :  Data.SRTree.Print 
-- Copyright   :  (c) Fabricio Olivetti 2021 - 2024
-- License     :  BSD3
-- Maintainer  :  fabricio.olivetti@gmail.com
-- Stability   :  experimental
-- Portability :  
--
-- Conversion functions to display the expression trees in different formats.
--
-----------------------------------------------------------------------------
module Data.SRTree.Print 
         ( showExpr
         , showExprWithVars
         , printExpr
         , printExprWithVars
         , showTikz
         , printTikz
         , showPython
         , printPython
         , showLatex
         , printLatex
         )
         where

import Control.Monad.Reader (Reader, asks, runReader)
import Data.Char (toLower)
import Data.SRTree.Internal
import Data.SRTree.Recursion (cata)

-- | Rewrites the protected operators of a tree into their conventional
-- mathematical form, so that the printers only have to deal with plain
-- operators:
--
--   * @Uni SqrtAbs t@    becomes @sqrt (abs t)@
--   * @Uni LogAbs t@     becomes @log (abs t)@
--   * @Uni Cube t@       becomes @t ** 3@
--   * @Bin AQ l r@       becomes @l / sqrt (1 + r * r)@
--   * @Bin PowerAbs l r@ becomes @abs l ** r@
--
-- The replacements are built through the 'Num' / 'Floating' instances of
-- @Fix SRTree@; every other node is rebuilt as-is.
removeProtection :: Fix SRTree -> Fix SRTree
removeProtection = cata unprotect
  where
    unprotect :: SRTree (Fix SRTree) -> Fix SRTree
    unprotect (Uni SqrtAbs t)    = sqrt (abs t)
    unprotect (Uni LogAbs t)     = log (abs t)
    unprotect (Uni Cube t)       = t ** 3
    unprotect (Bin AQ l r)       = l / sqrt (1 + r * r)
    unprotect (Bin PowerAbs l r) = abs l ** r
    -- leaves and unprotected operators: re-wrap the (already rewritten) node
    unprotect node               = Fix node

-- | Renders a tree as a fully parenthesized string in math notation.
-- Variables print as @x0@, @x1@, …, parameters as @t0@, @t1@, …, and
-- functions by their constructor name. Protected operators are expanded
-- first with 'removeProtection'.
--
-- >>> showExpr $ "x0" + sin ( tanh ("t0" + 2) )
-- "(x0 + Sin(Tanh((t0 + 2.0))))"
showExpr :: Fix SRTree -> String
showExpr = cata render . removeProtection
  where
    render :: SRTree String -> String
    render (Var ix)     = "x" <> show ix
    render (Param ix)   = "t" <> show ix
    render (Const c)    = show c
    render (Bin op l r) = "(" <> l <> " " <> showOp op <> " " <> r <> ")"
    render (Uni f t)    = show f <> "(" <> t <> ")"

-- | Renders a tree in math notation like 'showExpr'. The variable with
-- index @i@ is printed as the @i@-th entry of the given name list.
-- Indices without a name (negative, or past the end of the list) fall
-- back to the default @xi@ spelling instead of crashing.
--
-- >>> showExprWithVars ["mu", "eps"] $ "x0" + sin ( "x1" * tanh ("t0" + 2) )
-- "(mu + Sin((eps * Tanh((t0 + 2.0)))))"
showExprWithVars :: [String] -> Fix SRTree -> String
showExprWithVars varnames = cata alg . removeProtection
  where
    alg :: SRTree String -> String
    alg = \case
      Var ix     -> varName ix
      Param ix   -> 't' : show ix
      Const c    -> show c
      Bin op l r -> concat ["(", l, " ", showOp op, " ", r, ")"]
      Uni f t    -> concat [show f, "(", t, ")"]

    -- Total replacement for @varnames !! ix@; the guard on @ix@ is needed
    -- because 'drop' with a negative count returns the whole list.
    varName :: Int -> String
    varName ix
      | ix >= 0, (name : _) <- drop ix varnames = name
      | otherwise                                = 'x' : show ix

-- | Writes the math-notation rendering of a tree (see 'showExpr') to
-- standard output, followed by a newline.
printExpr :: Fix SRTree -> IO ()
printExpr tree = putStrLn (showExpr tree)

-- | Writes the math-notation rendering of a tree to standard output,
-- using the given names for the variables (see 'showExprWithVars').
printExprWithVars :: [String] -> Fix SRTree -> IO ()
printExprWithVars varnames tree = putStrLn $ showExprWithVars varnames tree

-- | Infix symbol used for a binary operator by the math-notation, numpy
-- and LaTeX printers. 'AQ' and 'PowerAbs' only get placeholder symbols:
-- those printers run 'removeProtection' first, which rewrites both.
showOp :: Op -> String
showOp op = case op of
  Add      -> "+"
  Sub      -> "-"
  Mul      -> "*"
  Div      -> "/"
  Power    -> "^"
  AQ       -> "_aq_"
  PowerAbs -> "||^"
{-# INLINE showOp #-}

-- | Displays a tree as a numpy compatible expression. Variables become
-- columns of @x@, parameters columns of @t@, and functions their @np.*@
-- counterparts. The inverse trigonometric and hyperbolic functions use
-- the @np.arc*@ names, which exist in every NumPy version.
--
-- Powers are parenthesized like every other binary operator, and a
-- negative base gets its own parentheses. Python's @**@ is
-- right-associative and binds tighter than unary minus, so without them
-- @(a ^ b) ^ c@ would print as @a ** b ** c@. Likewise, @(-2) ^ 2@ would
-- print as @-2.0 ** 2.0@, which Python evaluates to @-4.0@. Non-finite
-- constants are printed as @np.nan@ / @np.inf@ so the output stays valid
-- Python.
--
-- >>> showPython $ "x0" + sin ( tanh ("t0" + 2) )
-- "(x[:, 0] + np.sin(np.tanh((t[:, 0] + 2.0))))"
showPython :: Fix SRTree -> String
showPython = cata alg . removeProtection
  where
    alg :: SRTree String -> String
    alg = \case
      Var ix        -> concat ["x[:, ", show ix, "]"]
      Param ix      -> concat ["t[:, ", show ix, "]"]
      Const c       -> pyConst c
      Bin Power l r -> concat ["(", protectNeg l, " ** ", r, ")"]
      Bin op l r    -> concat ["(", l, " ", showOp op, " ", r, ")"]
      Uni f t       -> concat [pyFun f, "(", t, ")"]

    -- Haskell's 'show' spells non-finite doubles as NaN / Infinity,
    -- which are not Python literals.
    pyConst :: Double -> String
    pyConst c
      | isNaN c      = "np.nan"
      | isInfinite c = if c > 0 then "np.inf" else "-np.inf"
      | otherwise    = show c

    -- Only a constant can render with a leading minus sign: every other
    -- node starts with "x[", "t[", "(" or a function name.
    protectNeg :: String -> String
    protectNeg s@('-' : _) = concat ["(", s, ")"]
    protectNeg s           = s

    pyFun :: Function -> String
    pyFun Id     = ""
    pyFun Abs    = "np.abs"
    pyFun Sin    = "np.sin"
    pyFun Cos    = "np.cos"
    pyFun Tan    = "np.tan"
    pyFun Sinh   = "np.sinh"
    pyFun Cosh   = "np.cosh"
    pyFun Tanh   = "np.tanh"
    pyFun ASin   = "np.arcsin"
    pyFun ACos   = "np.arccos"
    pyFun ATan   = "np.arctan"
    pyFun ASinh  = "np.arcsinh"
    pyFun ACosh  = "np.arccosh"
    pyFun ATanh  = "np.arctanh"
    pyFun Sqrt   = "np.sqrt"
    pyFun Square = "np.square"
    pyFun Log    = "np.log"
    pyFun Exp    = "np.exp"
    pyFun Cbrt   = "np.cbrt"
    pyFun Recip  = "np.reciprocal"
    -- SqrtAbs, LogAbs and Cube are rewritten by removeProtection above.
    pyFun f      = error ("Data.SRTree.Print.showPython: unexpected protected function "
                          <> show f <> " (removeProtection should have rewritten it)")

-- | Writes the numpy rendering of a tree (see 'showPython') to standard
-- output, followed by a newline.
printPython :: Fix SRTree -> IO ()
printPython tree = putStrLn $ showPython tree

-- | Displays a tree as a LaTeX compatible math expression. Variables
-- print as @x_{i}@, parameters as @\\theta_{i}@, absolute values as
-- @\\left| … \\right|@, and other functions via 'showLatexFun'.
--
-- Powers are wrapped in @\\left( … \\right)@ like every other binary
-- operator. Without this, a nested power would produce @a^{b}^{c}@,
-- which is a LaTeX double-superscript error. A negative constant base
-- gets its own parentheses, so that @(-2)^2@ is not read as @-(2^2)@.
--
-- >>> showLatex $ "x0" + sin ( tanh ("t0" + 2) )
-- "\\left(x_{0} + \\operatorname{sin}(\\operatorname{tanh}(\\left(\\theta_{0} + 2.0\\right)))\\right)"
showLatex :: Fix SRTree -> String
showLatex = cata alg . removeProtection
  where
    alg :: SRTree String -> String
    alg = \case
      Var ix        -> concat ["x_{", show ix, "}"]
      Param ix      -> concat ["\\theta_{", show ix, "}"]
      Const c       -> show c
      Bin Power l r -> concat ["\\left(", protectNeg l, "^{", r, "}\\right)"]
      Bin op l r    -> concat ["\\left(", l, " ", showOp op, " ", r, "\\right)"]
      Uni Abs t     -> concat ["\\left |", t, "\\right |"]
      Uni f t       -> concat [showLatexFun f, "(", t, ")"]

    -- Only a constant can render with a leading minus sign: every other
    -- node starts with "x_", "\theta", "\left" or "\operatorname".
    protectNeg :: String -> String
    protectNeg s@('-' : _) = concat ["\\left(", s, "\\right)"]
    protectNeg s           = s

-- | LaTeX spelling of a function: its constructor name, lower-cased, set
-- inside @\\operatorname{…}@ (e.g. @Sin@ becomes @\\operatorname{sin}@).
showLatexFun :: Function -> String
showLatexFun f = "\\operatorname{" <> fmap toLower (show f) <> "}"
{-# INLINE showLatexFun #-}

-- | Writes the LaTeX rendering of a tree (see 'showLatex') to standard
-- output, followed by a newline.
printLatex :: Fix SRTree -> IO ()
printLatex tree = putStrLn (showLatex tree)

-- | Displays a tree in TikZ bracket notation. Each node is printed as
-- @[label children…]@, with each label and each closing bracket
-- followed by a newline.
--
-- Leaves are typeset in math mode:
--
--   * variables as @x_{i}@
--   * parameters as @\\theta_{i}@
--   * finite constants rounded to two decimals
--   * NaN and infinities as @\\mathrm{NaN}@ / @\\infty@
--
-- Operators use @+@, @-@, @×@, @÷@ and @\\^{}@. Functions use their
-- lower-cased constructor name.
showTikz :: Fix SRTree -> String
showTikz = cata alg . removeProtection
  where
    alg :: SRTree String -> String
    alg = \case
      Var ix     -> concat ["[$x_{", show ix, "}$]\n"]
      Param ix   -> concat ["[$\\theta_{", show ix, "}$]\n"]
      Const c    -> concat ["[$", showConst c, "$]\n"]
      Bin op l r -> concat ["[", showOpTikz op, l, r, "]\n"]
      Uni f t    -> concat ["[", map toLower $ show f, t, "]\n"]

    -- 'round' gives meaningless results on NaN and infinities, so those
    -- get LaTeX spellings instead of being rounded.
    showConst :: Double -> String
    showConst c
      | isNaN c      = "\\mathrm{NaN}"
      | isInfinite c = if c > 0 then "\\infty" else "-\\infty"
      | otherwise    = show (roundN 2 c)

    -- Round to @n@ decimal places.
    roundN :: Int -> Double -> Double
    roundN n x = fromIntegral (round (x * ten) :: Integer) / ten
      where
        ten = 10 ^ n

    showOpTikz :: Op -> String
    showOpTikz Add   = "+\n"
    showOpTikz Sub   = "-\n"
    showOpTikz Mul   = "×\n"
    showOpTikz Div   = "÷\n"
    showOpTikz Power = "\\^{}\n"
    -- AQ and PowerAbs are rewritten by removeProtection above.
    showOpTikz _     = error "Data.SRTree.Print.showTikz: unexpected protected operator (removeProtection should have rewritten it)"

-- | Writes the TikZ rendering of a tree (see 'showTikz') to standard
-- output, followed by a newline.
printTikz :: Fix SRTree -> IO ()
printTikz tree = putStrLn $ showTikz tree