{-# LANGUAGE FlexibleInstances, MultiParamTypeClasses, RankNTypes, FunctionalDependencies, UndecidableInstances, FlexibleContexts #-}

-----------------------------------------------------------------------------
-- |
-- Module      :  Data.Ring.Module.AutomaticDifferentiation
-- Copyright   :  (c) Edward Kmett 2009
-- License     :  BSD-style
-- Maintainer  :  ekmett@gmail.com
-- Stability   :  experimental
-- Portability :  non-portable (MPTCs, functional dependencies, rank-N types)
--
-----------------------------------------------------------------------------

module Data.Ring.Module.AutomaticDifferentiation 
    ( module Data.Ring.Module
    , D
    , d
    , lift
    ) where

import Prelude
import Data.Ring.Module
import Data.Monoid.Reducer
import Test.QuickCheck
import Control.Monad

-- | A value of type @r@ paired with a second component of type @m@.
--
-- The 'Num' and 'Multiplicative' instances in this module combine the second
-- component the way a derivative is combined (sum rule, product rule), so it
-- plays the role of a tangent carried alongside the primal value.  The @s@
-- parameter is phantom: it does not occur in either field.
data D s r m
    = D r m
    deriving (Show, Read)

-- | Embed a plain value as a 'D' whose second component is 'zero',
-- i.e. treat it as a constant.
lift :: Monoid m => r -> D s r m
lift = (`D` zero)

-- | The seed value handed to a function by 'd': first component 'zero',
-- second component 'one'.
infinitesimal :: (Monoid r, Multiplicative m) => D s r m
infinitesimal = D primal tangent
  where
    primal  = zero
    tangent = one

-- | Equality inspects only the first component; the second component is
-- ignored (and never forced).
instance Eq r => Eq (D s r m) where
    (==) (D a _) (D b _) = a == b

-- | Ordering is decided by the first component alone, matching the 'Eq'
-- instance; the second component is ignored.
instance Ord r => Ord (D s r m) where
    compare (D a _) (D b _) = a `compare` b

-- | Component-wise monoid: both components are combined independently with
-- their own 'mappend', and the identity pairs the two 'mempty' values.
instance (Monoid r, Monoid m) => Monoid (D s r m) where
    mempty = D mempty mempty
    mappend (D a da) (D b db) = D (mappend a b) (mappend da db)

-- | Multiplication of 'D' values.  The unit is 'one' paired with 'zero'.
instance (r `Module` m) => Multiplicative (D s r m) where
    one = D one zero
    -- First components are multiplied directly.  The second component
    -- combines a left scaling (*.) and a right scaling (.*) with 'plus'.
    -- NOTE(review): this reads as the product rule (a * db + da * b) provided
    -- '*.' and '.*' bind tighter than 'plus'; their fixities are declared
    -- outside this file -- confirm.
    times (D a da) (D b db) = D (a `times` b) (a *. db `plus` da .* b)

-- | Component-wise group structure: negation and both subtraction operators
-- are applied to the first and second components independently.
instance (Group r, r `Module` m, Group m) => Group (D s r m) where
    gnegate (D a da) = D (gnegate a) (gnegate da)
    minus (D a da) (D b db) = D (a `minus` b) (da `minus` db)
    gsubtract (D a da) (D b db) = D (a `gsubtract` b) (da `gsubtract` db)

-- | Numeric instance for the case where both components share one type.
-- The second component is propagated by the usual differentiation rules.
instance Num a => Num (D s a a) where
    -- Constants carry a zero second component.
    fromInteger n = D (fromInteger n) 0

    -- Sum and difference act component-wise.
    D a da + D b db = D (a + b) (da + db)
    D a da - D b db = D (a - b) (da - db)
    negate (D a da) = D (negate a) (negate da)

    -- Product rule; operand order is kept as written so the definition does
    -- not rely on (*) being commutative.
    D a da * D b db = D (a * b) (a * db + da * b)

    -- abs scales the second component by the sign of the first.
    abs (D a da) = D (abs a) (signum a * da)

    -- signum discards the incoming second component and yields 0 for it.
    signum (D a _) = D (signum a) 0

-- | Fractional instance for the case where both components share one type.
-- Division is not defined here, so the class default for '/' applies.
instance Fractional a => Fractional (D s a a) where
    -- Constants carry a zero second component.
    fromRational q = D (fromRational q) 0

    -- Reciprocal rule: the second component becomes -(da / a / a).
    recip (D a da) = D (recip a) (negate (da / a / a))

-- The instances below define no methods of their own; they only declare that
-- 'D' belongs to each class whenever the underlying @r@ does and @m@ is a
-- module over it.  The 'Ring' instance additionally requires @m@ to be a
-- 'Group'.
instance (LeftSemiNearRing r, r `Module` m) => LeftSemiNearRing (D s r m)

instance (RightSemiNearRing r, r `Module` m) => RightSemiNearRing (D s r m)

instance (SemiRing r, r `Module` m) => SemiRing (D s r m)

instance (Ring r, r `Module` m, Group m) => Ring (D s r m)

-- | Reduce an element into both components at once: every 'Reducer'
-- operation is forwarded to the first and to the second component with the
-- same element.
instance (c `Reducer` r, c `Reducer` m) => Reducer c (D s r m) where
    unit e = D (unit e) (unit e)
    cons e (D a da) = D (e `cons` a) (e `cons` da)
    snoc (D a da) e = D (a `snoc` e) (da `snoc` e)

-- | Random 'D' values are built from independently generated components.
instance (Arbitrary r, Arbitrary m) => Arbitrary (D s r m) where
    arbitrary = liftM2 D arbitrary arbitrary

    -- Shrink one component at a time while holding the other fixed.
    -- Combining the two shrink lists with 'liftM2' would take their cartesian
    -- product, which is empty as soon as either component cannot shrink and
    -- never tries the smaller step of shrinking a single component.
    shrink (D r m) =
        [ D r' m | r' <- shrink r ] ++
        [ D r m' | m' <- shrink m ]

-- | Perturb a generator using both components: the perturbation derived from
-- the first component is combined with the one derived from the second via
-- '><'.
instance (CoArbitrary r, CoArbitrary m) => CoArbitrary (D s r m) where
    coarbitrary (D a da) = coarbitrary a >< coarbitrary da

-- | Apply a function to 'infinitesimal' and return both components of the
-- result as a pair.
--
-- The rank-2 argument must work for every choice of the phantom @s@.  The
-- result pattern is matched lazily, so the pair is produced without forcing
-- the function's result.
d :: (Monoid r, Multiplicative m) => (forall s. D s r m -> D s r m) -> (r,m)
d f = case f infinitesimal of
    ~(D y dy) -> (y, dy)