{-# LANGUAGE Rank2Types, TypeFamilies, MultiParamTypeClasses, FunctionalDependencies, FlexibleInstances, FlexibleContexts, TemplateHaskell, UndecidableInstances, TypeOperators #-}
-----------------------------------------------------------------------------
-- |
-- Module      :  Numeric.AD.Internal.Composition
-- Copyright   :  (c) Edward Kmett 2010
-- License     :  BSD3
-- Maintainer  :  ekmett@gmail.com
-- Stability   :  experimental
-- Portability :  GHC only
--
-- Defines the composition of two AD modes as an AD mode in its own right
-----------------------------------------------------------------------------

module Numeric.AD.Internal.Composition
    ( (:.)(..)
    , On
    , compose
    , decompose
    ) where

import Numeric.AD.Classes
import Numeric.AD.Internal

-- | Composition of two AD modes. A value of @(f :. g) a@ is an outer
-- @f@-mode value whose elements are inner @'AD' g a@ values; 'O' wraps
-- such a value and 'runO' unwraps it.
newtype (:.) f g a = O { runO :: f (AD g a) }


-- | Argument-flipped spelling of '(:.)': in @On f g@ the mode @g@ is the
-- outer one and @f@ is the inner one.
type On f g = (:.) g f

-- | Repackage an @f@-mode AD value whose elements are @'AD' g a@ values as
-- a single AD value in the composed mode @f :. g@. Inverse of 'decompose'.
compose :: AD f (AD g a) -> AD (f :. g) a
compose = \(AD outer) -> AD (O outer)

-- | Split an AD value in the composed mode @f :. g@ back into an @f@-mode
-- value over @'AD' g a@ elements. Inverse of 'compose'.
decompose :: AD (f :. g) a -> AD f (AD g a)
decompose (AD wrapped) = AD (runO wrapped)

-- | The primal of a composed value: take the primal at the outer mode,
-- which yields an inner AD value, then take that value's primal.
instance (Primal f, Mode g, Primal g) => Primal (f :. g) where
    primal (O outer) = primal (primal outer)

-- | Mode operations on a composed value run at the outer mode @f@.
-- Plain scalars are first lifted into the inner mode so that they match
-- the outer mode's element type.
instance (Mode f, Mode g) => Mode (f :. g) where
    -- a constant is lifted through the inner mode, then the outer one
    lift x = O (lift (lift x))
    x <+> y = O (runO x <+> runO y)
    s *^ v = O (lift s *^ runO v)
    v ^* s = O (runO v ^* lift s)
    v ^/ s = O (runO v ^/ lift s)

-- | Lifted numeric operations on a composed value. Each method unwraps its
-- arguments with 'runO', delegates to the outer mode @f@ (whose elements
-- are @AD g a@), and rewraps composed results with 'O'.
instance (Mode f, Mode g) => Lifted (f :. g) where
    -- showing and comparing look straight through the wrapper
    showsPrec1 d = showsPrec1 d . runO
    x ==! y = runO x ==! runO y
    compare1 x y = compare1 (runO x) (runO y)

    -- arithmetic; integer constants are built at the inner level and
    -- lifted into the outer mode
    fromInteger1 n = O (lift (fromInteger1 n))
    x +! y = O (runO x +! runO y)
    x -! y = O (runO x -! runO y)
    x *! y = O (runO x *! runO y)
    negate1 = O . negate1 . runO
    abs1 = O . abs1 . runO
    signum1 = O . signum1 . runO

    -- division; rational constants follow the same lift pattern
    x /! y = O (runO x /! runO y)
    recip1 = O . recip1 . runO
    fromRational1 r = O (lift (fromRational1 r))
    toRational1 = toRational1 . runO

    -- transcendental functions
    pi1 = O pi1
    exp1 = O . exp1 . runO
    log1 = O . log1 . runO
    sqrt1 = O . sqrt1 . runO
    x **! y = O (runO x **! runO y)
    logBase1 x y = O (logBase1 (runO x) (runO y))
    sin1 = O . sin1 . runO
    cos1 = O . cos1 . runO
    tan1 = O . tan1 . runO
    asin1 = O . asin1 . runO
    acos1 = O . acos1 . runO
    atan1 = O . atan1 . runO
    sinh1 = O . sinh1 . runO
    cosh1 = O . cosh1 . runO
    tanh1 = O . tanh1 . runO
    asinh1 = O . asinh1 . runO
    acosh1 = O . acosh1 . runO
    atanh1 = O . atanh1 . runO

    -- rounding; the fractional part of properFraction1 stays composed.
    -- The let-bound tuple pattern is lazy, so the result pair is returned
    -- without forcing the inner call.
    properFraction1 x =
        let (whole, frac) = properFraction1 (runO x)
        in (whole, O frac)
    truncate1 = truncate1 . runO
    round1 = round1 . runO
    ceiling1 = ceiling1 . runO
    floor1 = floor1 . runO

    -- floating-point representation queries and constructors
    floatRadix1 = floatRadix1 . runO
    floatDigits1 = floatDigits1 . runO
    floatRange1 = floatRange1 . runO
    decodeFloat1 = decodeFloat1 . runO
    encodeFloat1 m = O . encodeFloat1 m
    exponent1 = exponent1 . runO
    significand1 = O . significand1 . runO
    scaleFloat1 n = O . scaleFloat1 n . runO
    isNaN1 = isNaN1 . runO
    isInfinite1 = isInfinite1 . runO
    isDenormalized1 = isDenormalized1 . runO
    isNegativeZero1 = isNegativeZero1 . runO
    isIEEE1 = isIEEE1 . runO
    atan21 x y = O (atan21 (runO x) (runO y))

    -- enumeration; list results are rewrapped element by element
    succ1 = O . succ1 . runO
    pred1 = O . pred1 . runO
    toEnum1 = O . toEnum1
    fromEnum1 = fromEnum1 . runO
    enumFrom1 = map O . enumFrom1 . runO
    enumFromThen1 x y = map O (enumFromThen1 (runO x) (runO y))
    enumFromTo1 x y = map O (enumFromTo1 (runO x) (runO y))
    enumFromThenTo1 x y z =
        map O (enumFromThenTo1 (runO x) (runO y) (runO z))

    -- bounds
    minBound1 = O minBound1
    maxBound1 = O maxBound1

-- deriveNumeric (conT `appT` varT (mkName "f") `appT` varT (mkName "g"))