{-# LANGUAGE FlexibleInstances, MultiParamTypeClasses #-}
module Data.Ring.Module.AutomaticDifferentiation 
    ( module Data.Ring.Module
    , D
    ) where

import Prelude hiding ((*),(+),(-),subtract,negate)
import Data.Ring.Sugar
import Data.Ring.Module
import Data.Monoid.Reducer
import Test.QuickCheck
import Control.Monad

-- | A dual number: a primal value of type @r@ paired with a tangent
-- (derivative) part of type @m@.
--
-- The instances in this module make addition act componentwise. They make
-- multiplication follow the product (Leibniz) rule
-- @D x dx * D y dy = D (x * y) (x *. dy + dx .* y)@, which is the arithmetic
-- of first-order forward-mode automatic differentiation.
data D r m
    = D r m

-- | Addition of dual numbers: the primal parts and the tangent parts are
-- combined independently with @(+)@. The identity is 'mempty' in both slots.
instance (Monoid r, Monoid m) => Monoid (D r m) where
    mempty = D mempty mempty
    mappend (D x1 dx1) (D x2 dx2) = D primal tangent
      where
        primal  = x1 + x2
        tangent = dx1 + dx2

-- | Multiplication of dual numbers by the product (Leibniz) rule.
--
-- Writing a dual number as @x + dx e@ with @e^2 = 0@:
-- @(x + dx e) * (y + dy e) = x * y + (x *. dy + dx .* y) e@.
-- The multiplicative unit has primal 'one' and a 'zero' tangent.
instance (Module r m) => Multiplicative (D r m) where
    one = D one zero
    times (D x dx) (D y dy) = D (x * y) tangent
      where
        -- x acts on dy via (*.) and y acts on dx via (.*). Two tangents are
        -- never multiplied together, so no second-order term appears.
        tangent = x *. dy + dx .* y

-- | Negation and subtraction of dual numbers. Each operation is applied
-- separately to the primal part and to the tangent part.
instance (Group r, Module r m, Group m) => Group (D r m) where
    gnegate (D x dx) = D (gnegate x) (gnegate dx)
    minus (D x dx) (D y dy) = D (minus x y) (minus dx dy)
    gsubtract (D x dx) (D y dy) = D (gsubtract x y) (gsubtract dx dy)

-- These instance bodies are deliberately empty. They declare that 'D r m'
-- carries the corresponding structure whenever @r@ (and, for 'Ring', also @m@
-- as a 'Group') does. Any class methods fall back to their defaults;
-- presumably the structure comes from D's Monoid, Multiplicative and Group
-- instances.
instance (LeftSemiNearRing r, Module r m) => LeftSemiNearRing (D r m) where
instance (RightSemiNearRing r, Module r m) => RightSemiNearRing (D r m) where
instance (SemiRing r, Module r m) => SemiRing (D r m) where
instance (Ring r, Module r m, Group m) => Ring (D r m) where

-- | Reduce a value of type @c@ into both components of a dual number at once.
-- 'unit', 'cons' and 'snoc' each apply the component's own 'Reducer'
-- operation to the primal and tangent parts independently.
--
-- NOTE(review): 'unit' also injects @c@ into the tangent rather than using
-- 'zero'. A reduced constant therefore does not automatically have a zero
-- derivative. Confirm this is intended.
instance (Reducer c r, Reducer c m) => Reducer c (D r m) where
    unit c = D (unit c) (unit c)
    cons c (D x dx) = D (cons c x) (cons c dx)
    snoc (D x dx) c = D (snoc x c) (snoc dx c)

-- | Random dual numbers for QuickCheck: the primal and tangent parts are
-- generated independently.
--
-- Shrinking shrinks one component at a time and keeps the other fixed. This
-- is the same strategy QuickCheck uses for pairs. The previous
-- @liftM2 D (shrink r) (shrink m)@ only produced candidates where both parts
-- shrank together. It therefore gave no candidates at all whenever either part
-- could not shrink, and the full product of both shrink lists otherwise.
instance (Arbitrary r, Arbitrary m) => Arbitrary (D r m) where
    arbitrary = liftM2 D arbitrary arbitrary
    shrink (D r m) = [ D r' m | r' <- shrink r ]
                  ++ [ D r m' | m' <- shrink m ]

-- | Perturb a generator by both the primal part and the tangent part of a
-- dual number. The two perturbations are combined with '(><)', presumably
-- QuickCheck's generator combinator.
instance (CoArbitrary r, CoArbitrary m) => CoArbitrary (D r m) where
    coarbitrary (D x dx) = perturbPrimal >< perturbTangent
      where
        perturbPrimal  = coarbitrary x
        perturbTangent = coarbitrary dx

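{--
Unfinished sketch, kept for reference only. It is not compiled and does not
type-check as written: 'K' is undefined, 'D' is used where 'AD' is meant, and
the Monoid instance defines (+) instead of mappend. It appears to explore a
lazily unfolded tower of higher-order derivatives.

infix 0 ><

(><) :: Multiplicative a => (a -> a) -> (AD a -> AD a) -> AD a -> AD a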
(f >< f') a@(AD a0 a') = D (f a0) (a' * f' a)

data AD r = AD r (Maybe (AD r))

instance (Monoid r) => Monoid (AD r) where
    mempty = K mempty
    AD x m + AD y n = D (x + y) (m + n)

instance (c `Reducer` r) => Reducer c (AD r) where
    unit c = c' where c' = AD (unit c) c'
--}