{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE CPP #-}
{-# LANGUAGE TypeFamilies #-}
-- |
-- Dual numbers specialised to 'Double' (a primal value paired with a
-- tangent), plus helpers that seed one element of a 'Traversable'
-- container at a time.
--
-- The pragmas above are required by this file itself: it uses @#if@ in the
-- import section, bang patterns in the @bind@ family, and associated type
-- instances in the 'Mode' and 'Jacobian' instances.
module Numeric.AD.Internal.Forward.Double
  ( ForwardDouble(..)
  , bundle
  , unbundle
  , apply
  , bind
  , bind'
  , bindWith
  , bindWith'
  , transposeWith
  ) where
#if __GLASGOW_HASKELL__ < 710
import Control.Applicative hiding ((<**>))
import Data.Foldable (Foldable, toList)
import Data.Traversable (Traversable, mapAccumL)
#else
import Data.Foldable (toList)
import Data.Traversable (mapAccumL)
#endif
import Control.Monad (join)
import Data.Function (on)
import Data.Number.Erf
import Numeric.AD.Internal.Combinators
import Numeric.AD.Internal.Identity
import Numeric.AD.Jacobian
import Numeric.AD.Mode
-- | A 'Double' carried together with a 'Double' tangent. Both fields are
-- strict.
data ForwardDouble = ForwardDouble
  { primal  :: !Double -- ^ the value component
  , tangent :: !Double -- ^ the derivative component paired with 'primal'
  } deriving (Read, Show)
-- | Split a 'ForwardDouble' into its @(primal, tangent)@ pair.
unbundle :: ForwardDouble -> (Double, Double)
unbundle dual = case dual of
  ForwardDouble x dx -> (x, dx)
-- | Build a 'ForwardDouble' from a primal value and a tangent, in that order.
bundle :: Double -> Double -> ForwardDouble
bundle x dx = ForwardDouble x dx
-- | Run a function on the given 'Double' seeded with tangent @1@.
apply :: (ForwardDouble -> b) -> Double -> b
apply f = f . (`bundle` 1)
-- | 'ForwardDouble' as a 'Mode' whose scalars are plain 'Double's.
instance Mode ForwardDouble where
  type Scalar ForwardDouble = Double

  -- A lifted scalar carries a zero tangent.
  auto x = ForwardDouble x 0
  zero   = ForwardDouble 0 0

  -- Known zero: both components compare equal to 0.
  isKnownZero (ForwardDouble x dx) = x == 0 && dx == 0

  -- Known constant: the tangent compares equal to 0.
  isKnownConstant (ForwardDouble _ dx) = dx == 0

  -- Scalar multiplication / division acts on both components.
  k *^ ForwardDouble x dx = ForwardDouble (k * x) (k * dx)
  ForwardDouble x dx ^* k = ForwardDouble (x * k) (dx * k)
  ForwardDouble x dx ^/ k = ForwardDouble (x / k) (dx / k)
-- | Componentwise sum of two dual numbers.
(<+>) :: ForwardDouble -> ForwardDouble -> ForwardDouble
(<+>) (ForwardDouble x dx) (ForwardDouble y dy) = ForwardDouble (x + y) (dx + dy)
-- | Lifting of scalar functions to 'ForwardDouble'. Partial derivatives are
-- passed around wrapped in 'Id'.
instance Jacobian ForwardDouble where
  type D ForwardDouble = Id Double

  -- The partial derivative is supplied directly.
  unary f (Id dfdx) (ForwardDouble x dx) = ForwardDouble (f x) (dfdx * dx)

  -- The derivative is computed from the input.
  lift1 f df (ForwardDouble x dx) =
    let Id dfdx = df (Id x)
    in ForwardDouble (f x) (dfdx * dx)

  -- The derivative function receives the output first, then the input.
  lift1_ f df (ForwardDouble x dx) =
    let y     = f x
        Id dy = df (Id y) (Id x) ^* dx
    in ForwardDouble y dy

  -- Both partial derivatives are supplied directly.
  binary f (Id dfdx) (Id dfdy) (ForwardDouble x dx) (ForwardDouble y dy) =
    ForwardDouble (f x y) (dfdx * dx + dy * dfdy)

  -- The partials are computed from the two inputs.
  lift2 f df (ForwardDouble x dx) (ForwardDouble y dy) =
    let z                  = f x y
        (Id dfdx, Id dfdy) = df (Id x) (Id y)
        dz                 = dfdx * dx + dy * dfdy
    in ForwardDouble z dz

  -- The partials function receives the output first, then both inputs.
  lift2_ f df (ForwardDouble x dx) (ForwardDouble y dy) =
    let z                  = f x y
        (Id dfdx, Id dfdy) = df (Id z) (Id x) (Id y)
        dz                 = dfdx * dx + dy * dfdy
    in ForwardDouble z dz
-- | Equality looks only at the 'primal' component; tangents are ignored.
instance Eq ForwardDouble where
  (==) = (==) `on` primal
-- | Ordering looks only at the 'primal' component; tangents are ignored.
instance Ord ForwardDouble where
  compare = compare `on` primal
-- | Arithmetic on dual numbers. Each method pairs the primal operation with
-- its partial derivative(s).
instance Num ForwardDouble where
  -- The literal 0 maps to 'zero'; other literals are lifted as constants.
  fromInteger 0 = zero
  fromInteger n = auto (fromInteger n)
  (+)           = (<+>)
  -- d(a - b) = 1 * da + (-1) * db
  (-)           = binary (-) (auto 1) (auto (-1))
  -- d(a * b) = b * da + a * db
  (*)           = lift2 (*) (\x y -> (y, x))
  -- d(negate a) = -1 * da
  negate        = lift1 negate (const (auto (-1)))
  -- d(abs a) = signum a * da
  abs           = lift1 abs signum
  -- signum is given a zero derivative.
  signum a      = lift1 signum (const zero) a
-- | Division on dual numbers, defined through 'recip'.
instance Fractional ForwardDouble where
  -- A rational equal to 0 maps to 'zero'; anything else is lifted as a
  -- constant.
  fromRational r
    | r == 0    = zero
    | otherwise = auto (fromRational r)

  x / y = x * recip y

  -- With y = recip x, the derivative is -(y * y); only the output is needed.
  recip = lift1_ recip $ \y _ -> negate (y * y)
-- | Elementary functions on dual numbers. Each method pairs the primal
-- function with its derivative.
instance Floating ForwardDouble where
  pi       = auto pi
  -- exp is its own derivative, so the output is reused.
  exp      = lift1_ exp const
  log      = lift1 log recip
  logBase x y = log y / log x
  -- d(sqrt x) = 1 / (2 * sqrt x), expressed through the output z.
  sqrt     = lift1_ sqrt (\z _ -> recip (auto 2 * z))
  -- A known-zero base stays a constant.
  ForwardDouble 0 0 ** ForwardDouble a _ = ForwardDouble (0 ** a) 0
  -- A known-zero exponent gives the constant 1.
  _ ** ForwardDouble 0 0                 = ForwardDouble 1 0
  -- Constant exponent: d(x ** y) = y * x ** (y - 1) * dx.
  x ** ForwardDouble y 0 = lift1 (**y) (\z -> y *^ z ** Id (y - 1)) x
  -- General case, with z = x ** y: partials are (y * z / x, z * log x).
  x ** y                 = lift2_ (**) (\z xi yi -> (yi * z / xi, z * log xi)) x y
  sin      = lift1 sin cos
  cos      = lift1 cos $ negate . sin
  tan      = lift1 tan $ recip . join (*) . cos
  -- d(asin x) = 1 / sqrt (1 - x^2)
  asin     = lift1 asin $ \x -> recip (sqrt (auto 1 - join (*) x))
  -- d(acos x) = -1 / sqrt (1 - x^2)
  acos     = lift1 acos $ \x -> negate (recip (sqrt (1 - join (*) x)))
  atan     = lift1 atan $ \x -> recip (1 + join (*) x)
  sinh     = lift1 sinh cosh
  cosh     = lift1 cosh sinh
  tanh     = lift1 tanh $ recip . join (*) . cosh
  asinh    = lift1 asinh $ \x -> recip (sqrt (1 + join (*) x))
  -- d(acosh x) = 1 / sqrt (x^2 - 1)
  acosh    = lift1 acosh $ \x -> recip (sqrt (join (*) x - 1))
  -- d(atanh x) = 1 / (1 - x^2)
  atanh    = lift1 atanh $ \x -> recip (1 - join (*) x)
-- | Enumeration on dual numbers, driven by the 'primal' component.
--
-- The list-producing methods enumerate over the primal values and rebuild
-- each element with 'withPrimal' or 'fromBy', which are defined outside this
-- file.
instance Enum ForwardDouble where
  succ                 = lift1 succ (const 1)
  pred                 = lift1 pred (const 1)
  toEnum               = auto . toEnum
  fromEnum             = fromEnum . primal
  enumFrom a           = withPrimal a <$> enumFrom (primal a)
  enumFromTo a b       = withPrimal a <$> enumFromTo (primal a) (primal b)
  -- delta is the step between the first two elements; it is passed to
  -- 'fromBy' together with the element index.
  enumFromThen a b     = zipWith (fromBy a delta) [0..] $ enumFromThen (primal a) (primal b) where delta = b - a
  enumFromThenTo a b c = zipWith (fromBy a delta) [0..] $ enumFromThenTo (primal a) (primal b) (primal c) where delta = b - a
-- | Conversion to 'Rational' uses the 'primal' component only.
instance Real ForwardDouble where
  toRational x = toRational (primal x)
-- | Floating-point queries are answered from the 'primal' component; the
-- remaining methods are lifted with explicit derivatives.
instance RealFloat ForwardDouble where
  floatRadix      = floatRadix . primal
  floatDigits     = floatDigits . primal
  floatRange      = floatRange . primal
  decodeFloat     = decodeFloat . primal
  -- An encoded value is a constant, so its tangent is zero.
  encodeFloat m e = auto (encodeFloat m e)
  isNaN           = isNaN . primal
  isInfinite      = isInfinite . primal
  isDenormalized  = isDenormalized . primal
  isNegativeZero  = isNegativeZero . primal
  isIEEE          = isIEEE . primal
  -- Delegates to the primal; @exponent = exponent@ would loop forever.
  exponent        = exponent . primal
  -- The tangent is scaled by the same factor as the primal.
  scaleFloat n = unary (scaleFloat n) (scaleFloat n 1)
  -- NOTE(review): the derivative scales by -floatDigits x rather than by
  -- -exponent x; confirm this is the intended factor.
  significand x =  unary significand (scaleFloat (- floatDigits x) 1) x
  -- With r = 1 / (x^2 + y^2), the partials of atan2 x y are (y * r, -x * r).
  atan2 = lift2 atan2 $ \vx vy -> let r = recip (join (*) vx + join (*) vy) in (vy * r, negate vx * r)
-- | Rounding operations look only at the 'primal' component.
instance RealFrac ForwardDouble where
  -- The fractional part of the primal is re-attached to the original dual
  -- number via 'withPrimal' (defined outside this file).
  properFraction x =
    let (whole, frac) = properFraction (primal x)
    in (whole, x `withPrimal` frac)
  truncate x = truncate (primal x)
  round x    = round (primal x)
  ceiling x  = ceiling (primal x)
  floor x    = floor (primal x)
-- | Error functions on dual numbers, each paired with its derivative.
instance Erf ForwardDouble where
  -- d(erf x) = 2 / sqrt pi * exp (-x^2)
  erf = lift1 erf $ \x -> (2 / sqrt pi) * exp (negate x * x)
  -- erfc x = 1 - erf x, so the derivative is the negation of erf's.
  erfc = lift1 erfc $ \x -> ((-2) / sqrt pi) * exp (negate x * x)
  -- The derivative of the normal CDF is the standard normal density,
  -- exp (-x^2 / 2) / sqrt (2 * pi).
  normcdf = lift1 normcdf $ \x -> recip (sqrt (2 * pi)) * exp (negate (x * x) / 2)
-- | Inverse error functions on dual numbers.
--
-- For an inverse function y = g x of f, dy/dx = 1 / f' y. The derivative is
-- therefore written in terms of the output y, which is what 'lift1_' passes
-- as the first argument.
instance InvErf ForwardDouble where
  -- 1 / erf' y = sqrt pi / 2 * exp (y^2)
  inverf = lift1_ inverf $ \y _ -> (sqrt pi / 2) * exp (y * y)
  -- 1 / erfc' y = -(sqrt pi / 2) * exp (y^2)
  inverfc = lift1_ inverfc $ \y _ -> negate (sqrt pi / 2) * exp (y * y)
  -- 1 / normcdf' y = sqrt (2 * pi) * exp (y^2 / 2)
  invnormcdf = lift1_ invnormcdf $ \y _ -> sqrt (2 * pi) * exp (y * y / 2)
-- | Evaluate @f@ once per element of @as@. In the k-th evaluation the k-th
-- element is seeded with tangent @1@ and every other element is lifted with
-- 'auto'. The results are returned in the shape of @as@.
bind :: Traversable f => (f ForwardDouble -> b) -> f Double -> f b
bind f as = snd (mapAccumL perInput (0 :: Int) as)
  where
    -- The element itself is ignored; only its position matters.
    perInput !k _ = (k + 1, f (seeded k))
    -- A copy of @as@ in which only position k carries tangent 1.
    seeded k = snd (mapAccumL (seedAt k) (0 :: Int) as)
    seedAt !k !ix x
      | k == ix   = (ix + 1, bundle x 1)
      | otherwise = (ix + 1, auto x)
-- | Like 'bind', but also returns a single result of @f@ as the first
-- component: the one from the last seeded evaluation, or @f@ applied to the
-- all-'auto' container when @as@ has no elements.
bind' :: Traversable f => (f ForwardDouble -> b) -> f Double -> (b, f b)
bind' f as =
  case mapAccumL perInput (0 :: Int, unseeded) as of
    ((_, final), results) -> (final, results)
  where
    -- The accumulator holds the next position and the most recent result.
    perInput (!k, _) _ =
      let r = f (seeded k)
      in ((k + 1, r), r)
    -- A copy of @as@ in which only position k carries tangent 1.
    seeded k = snd (mapAccumL (seedAt k) (0 :: Int) as)
    seedAt !k !ix x
      | k == ix   = (ix + 1, bundle x 1)
      | otherwise = (ix + 1, auto x)
    -- Fallback result, with every element lifted as a constant.
    unseeded = f (fmap auto as)
-- | Like 'bind', but each result of @f@ is combined with the original
-- element at that position using @g@.
bindWith :: Traversable f => (Double -> b -> c) -> (f ForwardDouble -> b) -> f Double -> f c
bindWith g f as = snd (mapAccumL perInput (0 :: Int) as)
  where
    perInput !k x = (k + 1, g x (f (seeded k)))
    -- A copy of @as@ in which only position k carries tangent 1.
    seeded k = snd (mapAccumL (seedAt k) (0 :: Int) as)
    seedAt !k !ix y
      | k == ix   = (ix + 1, bundle y 1)
      | otherwise = (ix + 1, auto y)
-- | Like 'bindWith', but also returns a single raw result of @f@ as the
-- first component: the one from the last seeded evaluation, or @f@ applied
-- to the all-'auto' container when @as@ has no elements.
bindWith' :: Traversable f => (Double -> b -> c) -> (f ForwardDouble -> b) -> f Double -> (b, f c)
bindWith' g f as =
  case mapAccumL perInput (0 :: Int, unseeded) as of
    ((_, final), results) -> (final, results)
  where
    -- The accumulator holds the next position and the most recent raw result.
    perInput (!k, _) x =
      let r = f (seeded k)
      in ((k + 1, r), g x r)
    -- A copy of @as@ in which only position k carries tangent 1.
    seeded k = snd (mapAccumL (seedAt k) (0 :: Int) as)
    seedAt !k !ix y
      | k == ix   = (ix + 1, bundle y 1)
      | otherwise = (ix + 1, auto y)
    -- Fallback result, with every element lifted as a constant.
    unseeded = f (fmap auto as)
-- | Walk the third argument left to right. For its n-th element @b@, gather
-- the n-th element of every inner container of @as@ into an @f a@ and
-- combine the two with @f@.
--
-- The inner containers are consumed with 'head' and 'tail', so each needs at
-- least as many elements as the third argument whenever the gathered values
-- are demanded.
transposeWith :: (Functor f, Foldable f, Traversable g) => (b -> f a -> c) -> f (g a) -> g b -> g c
transposeWith f as bs = snd (mapAccumL step rows bs)
  where
    -- Each inner container flattened to a list, outer shape preserved.
    rows = fmap toList as
    -- Peel the current heads off for this position and carry the tails on.
    step remaining b = (fmap tail remaining, f b (fmap head remaining))