{-# LANGUAGE FlexibleContexts       #-}
{-# LANGUAGE FlexibleInstances      #-}
{-# LANGUAGE MultiParamTypeClasses  #-}
{-# LANGUAGE QuantifiedConstraints  #-}
{-# LANGUAGE RecordWildCards        #-}
{-# LANGUAGE ScopedTypeVariables    #-}
{-# LANGUAGE UndecidableInstances   #-}
module Q.Stochastic.Discretize
        where

import           Data.Functor
import           Data.RVar
import           Numeric.LinearAlgebra
import           Q.Stochastic.Process
-- | Euler discretization of stochastic processes.
--
-- Carries the fixed time step 'eDt'. In the 'Discretize' instances of this
-- module, drift increments are scaled by 'eDt' and diffusion increments by
-- @sqrt eDt@. Both are evaluated at the (time, state) at the start of the
-- step.
newtype Euler = Euler
  { eDt :: Double -- ^ time step of the discretization
  }
  deriving (Show, Eq)

-- | Euler end-point discretization of stochastic processes.
--
-- Like 'Euler', this carries a fixed time step ('eeDt'). The difference is
-- that the process coefficients are evaluated at the end of the step,
-- @t0 + eeDt@, while the state from the start of the step is kept.
newtype EndEuler = EndEuler
  { eeDt :: Double -- ^ time step of the discretization
  }
  deriving (Show, Eq)


-- | Scalar Euler step. The drift is scaled by the time step and the
-- diffusion by its square root. Both are evaluated at the current
-- @(time, state)@ pair without modification.
instance Discretize Euler Double where
  -- drift(t, x) * dt
  dDrift p Euler{..} s0 = pDrift p s0 <&> (* eDt)
  -- diffusion(t, x) * sqrt dt
  -- NOTE(review): presumably the caller multiplies this by a standard
  -- normal draw; confirm in Q.Stochastic.Process.
  dDiff  p Euler{..} s0 = pDiff p s0 <&> (* sqrt eDt)
  -- The step size is constant: it ignores both the process and the state.
  dDt    _ Euler{..} _  = eDt

-- | Vector-valued Euler step. This mirrors the scalar instance, but scales
-- whole vectors with hmatrix's 'scale' instead of using scalar
-- multiplication. Drift and diffusion are evaluated at the current
-- @(time, state)@ pair.
instance Discretize Euler (Vector Double) where
  -- drift(t, x) * dt, element-wise
  dDrift p Euler{..} s0 = pDrift p s0 <&> scale eDt
  -- diffusion(t, x) * sqrt dt, element-wise
  dDiff  p Euler{..} s0 = pDiff p s0 <&> scale (sqrt eDt)
  -- The step size is constant: it ignores both the process and the state.
  dDt    _ Euler{..} _  = eDt

-- | Scalar end-point Euler step. Drift and diffusion are evaluated at the
-- end of the step, @(t0 + eeDt, x0)@: the time is advanced but the state is
-- kept from the start of the step. They are then scaled by @eeDt@ and
-- @sqrt eeDt@ respectively.
--
-- NOTE(review): this context requires every type @a@ to be a
-- @StochasticProcess a Double@, which is far stronger than the (empty)
-- context of the 'Euler' instances. Those instances call
-- 'pDrift' / 'pDiff' the same way, so the context may be unnecessary.
-- Confirm against the 'Discretize' class definition before removing it.
-- (The previously unused quantified variable @b@ has been dropped; the
-- constraint is otherwise unchanged.)
instance (forall a. StochasticProcess a Double)
      => Discretize EndEuler Double where
  -- drift(t0 + dt, x0) * dt
  dDrift p EndEuler{..} (t0, x0) = pDrift p (t0 + eeDt, x0) <&> (* eeDt)
  -- diffusion(t0 + dt, x0) * sqrt dt
  dDiff  p EndEuler{..} (t0, x0) = pDiff  p (t0 + eeDt, x0) <&> (* sqrt eeDt)
  -- The step size is constant: it ignores both the process and the state.
  dDt    _ EndEuler{..} _        = eeDt