{-# LANGUAGE IncoherentInstances #-} -- | This module provides a category transformer for automatic differentiation. -- -- There are many alternative notions of a generalized derivative. -- Perhaps the most common is the differential Ring. -- In Haskell, this might be defined as: -- -- > class Field r => Differential r where -- > derivative :: r -> r -- > -- > type Diff cat = forall a b. (Category cat, Differential cat a b) -- -- But this runs into problems with the lack of polymorphic constraints in GHC. -- See, for example . -- -- References: -- -- * module SubHask.Category.Trans.Derivative where import SubHask.Algebra import SubHask.Algebra.Vector import SubHask.Category import SubHask.SubType import SubHask.Internal.Prelude import qualified Prelude as P -------------------------------------------------------------------------------- -- | This is essentially just a translation of the "Numeric.AD.Forward.Forward" type -- for use with the SubHask numeric hierarchy. -- -- FIXME: -- -- Add reverse mode auto-differentiation for vectors. -- Apply the "ProofOf" framework from Monotonic data Forward a = Forward { val :: !a , val' :: a } deriving (Typeable,Show) mkMutable [t| forall a. Forward a |] instance Semigroup a => Semigroup (Forward a) where (Forward a1 a1')+(Forward a2 a2') = Forward (a1+a2) (a1'+a2') instance Cancellative a => Cancellative (Forward a) where (Forward a1 a1')-(Forward a2 a2') = Forward (a1-a2) (a1'-a2') instance Monoid a => Monoid (Forward a) where zero = Forward zero zero instance Group a => Group (Forward a) where negate (Forward a b) = Forward (negate a) (negate b) instance Abelian a => Abelian (Forward a) instance Rg a => Rg (Forward a) where (Forward a1 a1')*(Forward a2 a2') = Forward (a1*a2) (a1*a2'+a2*a1') instance Rig a => Rig (Forward a) where one = Forward one zero instance Ring a => Ring (Forward a) where fromInteger x = Forward (fromInteger x) zero instance Field a => Field (Forward a) where reciprocal (Forward a a') = Forward (reciprocal a) (-a'/(a*a)) (Forward a1 a1')/(Forward a2 a2') = Forward (a1/a2) ((a1'*a2+a1*a2')/(a2'*a2')) fromRational r = Forward (fromRational r) 0 --------- proveC1 :: (a ~ (a> (Forward a -> Forward a) -> C1 (a -> a) proveC1 f = Diffn (\a -> val $ f $ Forward a one) $ Diff0 $ \a -> val' $ f $ Forward a one proveC2 :: (a ~ (a> (Forward (Forward a) -> Forward (Forward a)) -> C2 (a -> a) proveC2 f = Diffn (\a -> val $ val $ f $ Forward (Forward a one) one) $ Diffn (\a -> val' $ val $ f $ Forward (Forward a one) one) $ Diff0 (\a -> val' $ val' $ f $ Forward (Forward a one) one) -------------------------------------------------------------------------------- class C (cat :: * -> * -> *) where type D cat :: * -> * -> * derivative :: cat a b -> D cat a (a >< b) data Diff (n::Nat) a b where Diff0 :: (a -> b) -> Diff 0 a b Diffn :: (a -> b) -> Diff (n-1) a (a >< b) -> Diff n a b --------- instance Sup (->) (Diff n) (->) instance Sup (Diff n) (->) (->) instance Diff 0 <: (->) where embedType_ = Embed2 unDiff0 where unDiff0 :: Diff 0 a b -> a -> b unDiff0 (Diff0 f) = f instance Diff n <: (->) where embedType_ = Embed2 unDiffn where unDiffn :: Diff n a b -> a -> b unDiffn (Diffn f f') = f -- -- FIXME: these subtyping instance should be made more generic -- the problem is that type families aren't currently powerful enough -- instance Sup (Diff 0) (Diff 1) (Diff 0) instance Sup (Diff 1) (Diff 0) (Diff 0) instance Diff 1 <: Diff 0 where embedType_ = Embed2 m2n where m2n (Diffn f f') = Diff0 f instance Sup (Diff 0) (Diff 2) (Diff 0) instance Sup (Diff 2) (Diff 0) (Diff 0) instance Diff 2 <: Diff 0 where embedType_ = Embed2 m2n where m2n (Diffn f f') = Diff0 f instance Sup (Diff 1) (Diff 2) (Diff 1) instance Sup (Diff 2) (Diff 1) (Diff 1) instance Diff 2 <: Diff 1 where embedType_ = Embed2 m2n where m2n (Diffn f f') = Diffn f (embedType2 f') --------- instance (1 <= n) => C (Diff n) where type D (Diff n) = Diff (n-1) derivative (Diffn f f') = f' unsafeProveC0 :: (a -> b) -> Diff 0 a b unsafeProveC0 f = Diff0 f unsafeProveC1 :: (a -> b) -- ^ f(x) -> (a -> a> C1 (a -> b) unsafeProveC1 f f' = Diffn f $ unsafeProveC0 f' unsafeProveC2 :: (a -> b) -- ^ f(x) -> (a -> a> (a -> a> C2 (a -> b) unsafeProveC2 f f' f'' = Diffn f $ unsafeProveC1 f' f'' type C0 a = C0_ a type family C0_ (f :: *) :: * where C0_ (a -> b) = Diff 0 a b type C1 a = C1_ a type family C1_ (f :: *) :: * where C1_ (a -> b) = Diff 1 a b type C2 a = C2_ a type family C2_ (f :: *) :: * where C2_ (a -> b) = Diff 2 a b --------------------------------------- -- algebra mkMutable [t| forall n a b. Diff n a b |] instance Semigroup b => Semigroup (Diff 0 a b) where (Diff0 f1 )+(Diff0 f2 ) = Diff0 (f1+f2) instance (Semigroup b, Semigroup (a> Semigroup (Diff 1 a b) where (Diffn f1 f1')+(Diffn f2 f2') = Diffn (f1+f2) (f1'+f2') instance (Semigroup b, Semigroup (a> Semigroup (Diff 2 a b) where (Diffn f1 f1')+(Diffn f2 f2') = Diffn (f1+f2) (f1'+f2') instance Monoid b => Monoid (Diff 0 a b) where zero = Diff0 zero instance (Monoid b, Monoid (a> Monoid (Diff 1 a b) where zero = Diffn zero zero instance (Monoid b, Monoid (a> Monoid (Diff 2 a b) where zero = Diffn zero zero -------------------------------------------------------------------------------- -- test -- v = unsafeToModule [1,2,3,4,5] :: SVector 5 Double -- -- sphere :: Hilbert v => C0 (v -> Scalar v) -- sphere = unsafeProveC0 f -- where -- f v = v<>v