module Algebra.AD (D(..),E(..),subst,dVar,var,sqrtE) where
import Algebra.Classes hiding ((:+))
import Data.Map (Map)
import qualified Data.Map.Strict as M
import Prelude hiding (Num(..),(/),fromRational,recip)
import Data.Function (on)
-- | A small expression syntax tree over variables of type @v@ and constants
-- of type @c@. It is not exported by this module.
data AST v c
  = V v                 -- ^ a variable
  | AST v c :* AST v c  -- ^ product
  | AST v c :+ AST v c  -- ^ sum
  | AST v c :- AST v c  -- ^ difference
  | K c                 -- ^ a constant

-- | Renders an 'AST' with infix operators, adding parentheses only where
-- precedence requires them: @:*@ binds at level 3, @:+@ and @:-@ at level 2.
-- Every constructor is handled, so 'show' is total on 'AST'.
instance (Show v, Show c) => Show (AST v c) where
  -- leaves forward the surrounding precedence, per the 'showsPrec' contract
  showsPrec p (V v) = showsPrec p v
  showsPrec p (K c) = showsPrec p c
  showsPrec p (x :+ y) =
    parens (p > 2) (showsPrec 2 x . showString " + " . showsPrec 2 y)
  showsPrec p (x :- y) =
    -- right operand one level higher: subtraction is not associative,
    -- so a - (b - c) must keep its parentheses
    parens (p > 2) (showsPrec 2 x . showString " - " . showsPrec 3 y)
  showsPrec p (x :* y) =
    parens (p > 3) (showsPrec 3 x . showString " * " . showsPrec 3 y)

-- | @parens b s@ wraps the rendering @s@ in parentheses when @b@ holds.
parens :: Bool -> ShowS -> ShowS
parens = showParen
-- | A dual number for forward-mode differentiation: a primal value together
-- with its partial derivatives, keyed by variable. Both fields are strict.
-- Constants carry a 'zero' derivative map, so a variable missing from the map
-- is understood to have derivative zero.
data D v c = D
  { dValue  :: !c          -- ^ the primal value
  , dDerivs :: !(Map v c)  -- ^ partial derivative with respect to each variable
  }
  deriving Show
-- | @dVar x c@ seeds the variable @x@ at the point @c@: the value is @c@ and
-- the only recorded partial derivative is @dx/dx = 1@.
dVar :: forall v c. Ring c => v -> c -> D v c
dVar x c = D { dValue = c, dDerivs = M.singleton x 1 }
-- | The expression made of the single variable @v@; evaluating it looks @v@
-- up in the supplied environment.
var :: (Multiplicative c, Additive c, Ord v) => v -> E v c
var v = E ($ v)
-- | Addition is component-wise: values add and derivative maps add.
instance (Ord v, Additive c) => Additive (D v c) where
  zero = D { dValue = zero, dDerivs = zero }
  (+) (D x dx) (D y dy) = D { dValue = x + y, dDerivs = dx + dy }
-- | Negation and subtraction act component-wise, on the primal value and on
-- every partial derivative.
instance (Ord v, Group c) => Group (D v c) where
  negate (D x d) = D (negate x) (negate d)
  -- the subtraction operator had been lost from this equation, leaving an
  -- ill-formed pattern; restored so that d(x - y) = dx - dy
  D v1 d1 - D v2 d2 = D (v1 - v2) (d1 - d2)
-- | Orders dual numbers by their primal value alone; derivatives are ignored.
instance Ord c => Ord (D v c) where
  compare a b = compare (dValue a) (dValue b)
-- | Equality compares primal values only; derivatives are ignored.
instance Eq c => Eq (D v c) where
  a == b = dValue a == dValue b
-- | Multiplication follows the product rule: @d(x * y) = y * dx + x * dy@.
-- The unit is the constant 'one' with no derivatives.
instance (Ord v, Ring c) => Multiplicative (D v c) where
  one = D { dValue = one, dDerivs = zero }
  (*) (D x dx) (D y dy) = D (x * y) (y *^ dx + x *^ dy)
-- | Marks addition on 'D' as commutative whenever it is for @c@; the class
-- has no methods to define here.
instance (AbelianAdditive c,Ord v) => AbelianAdditive (D v c)
-- | Scaling by a plain coefficient @k@ scales the value and every partial
-- derivative.
instance (Ord v, Ring c) => Module c (D v c) where
  (*^) k (D x dx) = D { dValue = k * x, dDerivs = k *^ dx }
-- | A dual number acting on another dual number is plain multiplication.
instance (Ord v, Ring c) => Module (D v c) (D v c) where
  a *^ b = a * b
-- | Integers embed as constants: converted value, zero derivatives.
instance (Ord v, Ring c) => Ring (D v c) where
  fromInteger n = D { dValue = fromInteger n, dDerivs = zero }
-- | An expression over variables @v@, represented by its evaluator: given an
-- environment assigning a dual number to each variable, it yields the dual
-- number (value and partial derivatives) of the whole expression.
newtype E v c = E
  { fromE :: (v -> D v c) -> D v c  -- ^ evaluate in an environment
  }
-- | Expressions add pointwise: both summands are evaluated in the same
-- environment and their dual numbers added.
instance (Ord v, Additive c) => Additive (E v c) where
  zero = E (\_ -> zero)
  E f + E g = E (\env -> f env + g env)
-- | Negation and subtraction of expressions act pointwise in the environment,
-- delegating to the 'Group' instance of 'D'.
instance (Ord v, Group c) => Group (E v c) where
  negate (E x) = E (negate . x)
  -- the operator had been lost here (`() = liftE2 ()` is ill-formed);
  -- restored to lift D's subtraction
  (-) = liftE2 (-)
-- | Expressions multiply pointwise; derivatives follow the product rule of
-- the 'D' instance.
instance (Ord v, Ring c) => Multiplicative (E v c) where
  one = E (\_ -> one)
  E f * E g = E (\env -> f env * g env)
-- | Marks expression addition as commutative whenever it is for @c@.
-- Constrained on 'AbelianAdditive' (matching the instance for 'D') rather
-- than 'Ring', so expressions over non-ring coefficients also qualify; every
-- 'Ring' is 'AbelianAdditive', so existing uses are unaffected.
instance (Ord v, AbelianAdditive c) => AbelianAdditive (E v c)
-- | Scaling an expression by a coefficient scales its evaluated dual number.
instance (Ord v, Ring c) => Module c (E v c) where
  k *^ E f = E (\env -> k *^ f env)
-- | An expression acting on another expression is plain multiplication.
instance (Ord v, Ring c) => Module (E v c) (E v c) where
  a *^ b = a * b
-- | An integer is the constant expression that ignores its environment.
instance (Ord v, Ring c) => Ring (E v c) where
  fromInteger n = E (const (fromInteger n))
-- | Lifts a binary operation on dual numbers to expressions: both operands
-- are evaluated in the same environment and the results combined with @f@.
liftE2 :: forall t t1. (D t t1 -> D t t1 -> D t t1) -> E t t1 -> E t t1 -> E t t1
liftE2 f (E x) (E y) = E (f <$> x <*> y)  -- reader applicative on (->) env
-- | Lifts a unary operation on dual numbers to expressions by applying it
-- to the evaluated result.
liftE :: forall t t1. (D t t1 -> D t t1) -> E t t1 -> E t t1
liftE f (E x) = E (f . x)
-- | @subst e f@ replaces every variable @a@ of @e@ with the expression @f a@.
-- Each substituted expression is evaluated in the environment given to the
-- result.
subst :: E v c -> (v -> E v c) -> E v c
subst (E p) f = E (\env -> p ((`fromE` env) . f))
-- | Square root of a dual number, with @d(sqrt x) = dx / (2 * sqrt x)@.
-- The derivative factor @0.5 / sqrt x@ divides by zero when the value is 0.
sqrtD :: (Ord v, Floating c, Field c) => D v c -> D v c
sqrtD (D x dx) =
  let r = sqrt x
  in D r ((0.5 / r) *^ dx)
-- | Square root of an expression, differentiated via 'sqrtD'.
sqrtE :: forall t t1. (Floating t1, Ord t, Field t1) => E t t1 -> E t t1
sqrtE (E x) = E (sqrtD . x)
-- | Reciprocal with @d(1/x) = -dx / x^2@. '/' is not defined here, so it
-- comes from the class default.
instance (Field c, Ord v) => Division (D v c) where
  recip (D x dx) = D inv (negate (inv * inv) *^ dx)
    where
      inv = recip x
-- | Reciprocal and division of expressions act pointwise in the environment,
-- delegating to the 'Division' instance of 'D'.
instance (Field c, Ord v) => Division (E v c) where
  recip (E x) = E (recip . x)
  E x / E y = E (\env -> x env / y env)