{-# OPTIONS_GHC -Wall #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}

-- | Dual numbers for forward-mode automatic differentiation.
--
-- A 'Dual' pairs a value with a first-order perturbation. The numeric
-- instances below propagate the perturbation using the usual derivative
-- rules, so evaluating a function at @Dual x 1@ yields its derivative at @x@
-- in the perturbation component.
module Dvda.Dual
  ( Dual(..)
  , fad
  , fad'
  ) where

-- NOTE(review): 'numerator' and 'denominator' are no longer referenced by the
-- code in this module as visible here ('fromRational' now delegates to the
-- underlying type). The import is kept in case other code relies on it;
-- drop it if nothing else does.
import Data.Ratio ( numerator, denominator )

-- | A value together with its first-order perturbation.
data Dual a = Dual
  { dualPrimal       :: a  -- ^ the value itself
  , dualPerturbation :: a  -- ^ the derivative carried alongside the value
  } deriving (Show, Eq)

-- | Arithmetic with sum, difference and product rules on the perturbation.
instance Num a => Num (Dual a) where
  (Dual x x') + (Dual y y') = Dual (x + y) (x' + y')
  (Dual x x') - (Dual y y') = Dual (x - y) (x' - y')
  -- Product rule.
  (Dual x x') * (Dual y y') = Dual (x * y) (x * y' + x' * y)
  negate (Dual x x') = Dual (negate x) (negate x')
  abs (Dual x x') = Dual (abs x) (signum x * x')
  -- Technically the perturbation should be a Dirac delta at zero.
  signum (Dual x _) = Dual (signum x) 0
  -- Constants carry no perturbation.
  fromInteger n = Dual (fromInteger n) 0

-- | Division with the quotient rule; rational constants carry no perturbation.
instance Fractional a => Fractional (Dual a) where
  -- Quotient rule.
  (Dual x x') / (Dual y y') = Dual (x / y) (x' / y - x / (y * y) * y')
  -- Convert through the underlying type's own 'fromRational' instead of
  -- dividing a separately converted numerator and denominator: the latter
  -- goes wrong when either part is out of range for @a@ (for 'Double',
  -- @1.0e-320@ collapsed to 0 and @1.0e400@ produced a NaN perturbation).
  fromRational r = Dual (fromRational r) 0

-- | Elementary functions, each paired with its derivative via the chain rule.
instance Floating a => Floating (Dual a) where
  pi = Dual pi 0

  exp (Dual x x') = Dual ex (ex * x')
    where
      ex = exp x

  sqrt (Dual x x') = Dual root (x' / (2 * root))
    where
      root = sqrt x

  log (Dual x x') = Dual (log x) (x' / x)

  -- d(x**y) = (x'*y + x*y'*log x) * x**(y-1).
  -- NOTE(review): the @log x@ term is evaluated even when @y'@ is zero, so
  -- for IEEE floating types a non-positive base presumably yields a NaN
  -- perturbation; this cannot be special-cased without an 'Eq' constraint.
  (Dual x x') ** (Dual y y') =
    Dual (x ** y) ((x' * y + x * y' * log x) * x ** (y - 1))

  logBase (Dual b b') (Dual e e') = Dual primal pert
    where
      primal = logBase b e
      pert   = (e' / e - primal * b' / b) / log b

  sin (Dual x x') = Dual (sin x) (cos x * x')
  cos (Dual x x') = Dual (cos x) (negate (sin x * x'))
  tan (Dual x x') = Dual (tan x) (x' / (c * c))
    where
      c = cos x

  asin (Dual x x') = Dual (asin x) (x' / sqrt (1 - x * x))
  acos (Dual x x') = Dual (acos x) (negate (x' / sqrt (1 - x * x)))
  atan (Dual x x') = Dual (atan x) (x' / (1 + x * x))

  sinh (Dual x x') = Dual (sinh x) (cosh x * x')
  cosh (Dual x x') = Dual (cosh x) (sinh x * x')
  tanh (Dual x x') = Dual (tanh x) (x' / (ch * ch))
    where
      ch = cosh x

  asinh (Dual x x') = Dual (asinh x) (x' / sqrt (1 + x * x))
  acosh (Dual x x') = Dual (acosh x) (x' / (sqrt (x - 1) * sqrt (x + 1)))
  atanh (Dual x x') = Dual (atanh x) (x' / (1 - x * x))

-- | Forward derivative propagation for a list-valued function.
--
-- The function is evaluated once at @Dual x 1@ and the perturbation of every
-- output is returned, i.e. informally
-- @fad' (\\x -> [sin x, 2*x]) x == [cos x, 2]@.
--
-- >>> fad' (\x -> [x*x, 2*x]) (3 :: Double)
-- [6.0,2.0]
fad' :: Num a => (Dual a -> [Dual a]) -> a -> [a]
fad' f x = map dualPerturbation (f (Dual x 1))

-- | Forward derivative propagation for a scalar function.
--
-- The function is evaluated at @Dual x 1@ and the resulting perturbation is
-- returned, i.e. informally @fad sin x == cos x@.
--
-- >>> fad sin (0 :: Double)
-- 1.0
fad :: Num a => (Dual a -> Dual a) -> a -> a
fad f x = dualPerturbation (f (Dual x 1))