module Numeric.AD.Internal.Classes
    (
    
      Mode(..)
    , one
    
    , Jacobian(..)
    , Primal(..)
    , deriveLifted
    , deriveNumeric
    , Lifted(..)
    , Iso(..)
    ) where
import Control.Applicative
import Data.Char
import Language.Haskell.TH
import Numeric.AD.Internal.Combinators (on)
-- Fixities for the lifted operators declared in this module. The precedence
-- levels mirror the Prelude operators they correspond to: 8 for
-- exponentiation, 7 for multiplicative, 6 for additive, 4 for comparison.
-- NOTE(review): Prelude's (**) is infixr 8, whereas (**!) is declared infixl
-- here -- confirm this is intended.
infixl 8 **!
infixl 7 *!, /!, ^*, *^, ^/
infixl 6 +!, -!, <+>
infix 4 ==!
-- | A two-way conversion between @f a@ and @f b@ that works under any type
-- constructor @f@.
class Iso a b where
    -- | Convert from @f a@ to @f b@.
    iso :: f a -> f b
    -- | Convert from @f b@ back to @f a@.
    osi :: f b -> f a
-- | Any type converts to itself; both directions return the argument unchanged.
instance Iso a a where
    iso fa = fa
    osi fa = fa
-- | Versions of the standard numeric class methods that operate on @t a@ for
-- a type constructor @t@, with the class constraint placed on the element
-- type @a@ rather than on @t a@.
--
-- Method names are the Prelude names with a @1@ suffix (alphanumeric names)
-- or a @!@ suffix (operators); 'lowerInstance' relies on this naming scheme
-- to map ordinary class methods onto these members.
class Lifted t where
    -- Show
    showsPrec1          :: Show a => Int -> t a -> ShowS
    -- Eq / Ord
    (==!)               :: (Num a, Eq a) => t a -> t a -> Bool
    compare1            :: (Num a, Ord a) => t a -> t a -> Ordering
    -- Num
    fromInteger1        :: Num a => Integer -> t a
    (+!),(-!),(*!)      :: Num a => t a -> t a -> t a
    negate1, abs1, signum1 :: Num a => t a -> t a
    -- Fractional
    (/!)                :: Fractional a => t a -> t a -> t a
    recip1              :: Fractional a => t a -> t a
    fromRational1       :: Fractional a => Rational -> t a
    -- Real
    toRational1         :: Real a => t a -> Rational 
    -- Floating
    pi1                 :: Floating a => t a
    exp1, log1, sqrt1   :: Floating a => t a -> t a
    (**!), logBase1     :: Floating a => t a -> t a -> t a
    sin1, cos1, tan1, asin1, acos1, atan1 :: Floating a => t a -> t a
    sinh1, cosh1, tanh1, asinh1, acosh1, atanh1 :: Floating a => t a -> t a
    -- RealFrac
    properFraction1 :: (RealFrac a, Integral b) => t a -> (b, t a)
    truncate1, round1, ceiling1, floor1 :: (RealFrac a, Integral b) => t a -> b
    -- RealFloat
    floatRadix1     :: RealFloat a => t a -> Integer
    floatDigits1    :: RealFloat a => t a -> Int
    floatRange1     :: RealFloat a => t a -> (Int, Int)
    decodeFloat1    :: RealFloat a => t a -> (Integer, Int)
    encodeFloat1    :: RealFloat a => Integer -> Int -> t a
    exponent1       :: RealFloat a => t a -> Int
    significand1    :: RealFloat a => t a -> t a
    scaleFloat1     :: RealFloat a => Int -> t a -> t a
    isNaN1, isInfinite1, isDenormalized1, isNegativeZero1, isIEEE1 :: RealFloat a => t a -> Bool
    atan21          :: RealFloat a => t a -> t a -> t a
    -- Enum (note the additional Num constraint on the element type)
    succ1, pred1    :: (Num a, Enum a) => t a -> t a
    toEnum1         :: (Num a, Enum a) => Int -> t a
    fromEnum1       :: (Num a, Enum a) => t a -> Int
    enumFrom1       :: (Num a, Enum a) => t a -> [t a]
    enumFromThen1   :: (Num a, Enum a) => t a -> t a -> [t a]
    enumFromTo1     :: (Num a, Enum a) => t a -> t a -> [t a]
    enumFromThenTo1 :: (Num a, Enum a) => t a -> t a -> t a -> [t a]
    -- Bounded (note the additional Num constraint on the element type)
    minBound1       :: (Num a, Bounded a) => t a
    maxBound1       :: (Num a, Bounded a) => t a
-- | A 'Lifted' type constructor that additionally supports embedding plain
-- values, addition, and scaling by plain values.
class Lifted t => Mode t where
    -- | Embed a plain value into the mode.
    lift  :: Num a => a -> t a
    -- | Addition of two mode values; 'deriveLifted' uses it as the
    -- implementation of '+!'.
    (<+>) :: Num a => t a -> t a -> t a
    -- | Multiply by a plain value on the left. Defaults to lifting the
    -- scalar and using '*!'.
    (*^) :: Num a => a -> t a -> t a
    -- | Multiply by a plain value on the right. Defaults to lifting the
    -- scalar and using '*!'.
    (^*) :: Num a => t a -> a -> t a
    -- | Divide by a plain value. Defaults to '^*' with the reciprocal.
    (^/) :: Fractional a => t a -> a -> t a
    -- | The lifted constant zero. Defaults to @'lift' 0@.
    zero :: Num a => t a
    a *^ b = lift a *! b
    a ^* b = a *! lift b
    a ^/ b = a ^* recip b
    zero = lift 0
-- | The constant @1@ embedded into the mode with 'lift'.
one :: (Mode t, Num a) => t a
one = lift unit
  where
    unit = 1
-- | The constant @-1@ embedded into the mode with 'lift'.
--
-- It is the derivative factor of negation and of the second operand of
-- subtraction, so it must be distinct from 'one'.
negOne :: (Mode t, Num a) => t a
negOne = lift (-1)
-- | Type constructors from which a plain value of the element type can be
-- extracted.
class Primal t where
    -- | Extract the underlying plain value; the @discrete*@ helpers in this
    -- module apply ordinary functions to it.
    primal :: Num a => t a -> a
-- | Modes whose values can be transformed by giving a function on plain
-- values together with companion data expressed in the associated mode
-- @D t@.
--
-- NOTE(review): the per-method notes below are inferred from how
-- 'deriveLifted' calls these methods in this module (e.g. @lift1 log recip1@,
-- @lift1_ exp const@); confirm against the concrete instances.
class (Mode t, Mode (D t)) => Jacobian t where
    -- | The associated mode in which the companion (presumably derivative)
    -- values are expressed.
    type D t :: * -> *
    -- | One-argument transformation with a fixed companion value.
    unary  :: Num a => (a -> a) -> D t a -> t a -> t a
    -- | One-argument transformation whose companion value is computed by
    -- the second function -- presumably from the argument.
    lift1  :: Num a => (a -> a) -> (D t a -> D t a) -> t a -> t a
    -- | Like 'lift1', but the second function takes two values --
    -- presumably the result first, then the argument.
    lift1_ :: Num a => (a -> a) -> (D t a -> D t a -> D t a) -> t a -> t a
    -- | Two-argument transformation with one fixed companion value per
    -- operand.
    binary :: Num a => (a -> a -> a) -> D t a -> D t a -> t a -> t a -> t a
    -- | Two-argument transformation whose pair of companion values is
    -- computed by the second function -- presumably from both arguments.
    lift2  :: Num a => (a -> a -> a) -> (D t a -> D t a -> (D t a, D t a)) -> t a -> t a -> t a
    -- | Like 'lift2', but the second function takes three values --
    -- presumably the result first, then both arguments.
    lift2_ :: Num a => (a -> a -> a) -> (D t a -> D t a -> D t a -> (D t a, D t a)) -> t a -> t a -> t a
-- | Apply 'unary' to @t@ with a function that ignores its input and yields
-- @a@, using 'one' as the companion value.
withPrimal :: (Jacobian t, Num a) => t a -> a -> t a
withPrimal t a = unary replacePrimal one t
  where
    replacePrimal _ = a
-- | Combine @a@ and @delta@ with 'binary', using a function that always
-- yields @x@, with companion value 'one' for @a@ and the lifted index @n@
-- for @delta@.
fromBy :: (Jacobian t, Num a) => t a -> t a -> Int -> a -> t a
fromBy a delta n x = binary always one (fromIntegral1 n) a delta
  where
    always _ _ = x
-- | Lifted analogue of 'fromIntegral': convert to 'Integer', then apply
-- 'fromInteger1'.
fromIntegral1 :: (Integral n, Lifted t, Num a) => n -> t a
fromIntegral1 n = fromInteger1 (fromIntegral n)
-- | Multiply a lifted value by itself using '*!'.
square1 :: (Lifted t, Num a) => t a -> t a
square1 = \v -> v *! v
-- | Apply a function to the 'primal' part of its argument.
discrete1 :: (Primal t, Num a) => (a -> c) -> t a -> c
discrete1 f = f . primal
-- | Apply a two-argument function to the 'primal' parts of both arguments.
discrete2 :: (Primal t, Num a) => (a -> a -> c) -> t a -> t a -> c
discrete2 f x y = f px py
  where
    px = primal x
    py = primal y
-- | Apply a three-argument function to the 'primal' parts of all three
-- arguments.
discrete3 :: (Primal t, Num a) => (a -> a -> a -> d) -> t a -> t a -> t a -> d
discrete3 f x y z = f px py pz
  where
    px = primal x
    py = primal y
    pz = primal z
-- | Generate a 'Lifted' instance for the type constructor @_t@.
--
-- The instance is written once as a quoted declaration and then rebuilt with
-- its context passed through @f@, so the caller can add or replace
-- constraints. The method bodies use 'primal', 'lift', '<+>' and the
-- 'Jacobian' combinators, so the resulting context presumably has to make
-- 'Primal', 'Mode' and 'Jacobian' instances for @_t@ available.
--
-- In the @lift1_@ / @lift2_@ based definitions the first lambda argument is
-- used as the already computed result (e.g. @exp1@ reuses it directly).
deriveLifted :: ([Q Pred] -> [Q Pred]) -> Q Type -> Q [Dec]
deriveLifted f _t = do
        [InstanceD cxt0 type0 dec0] <- lifted
        return <$> instanceD (cxt (f (return <$> cxt0))) (return type0) (return <$> dec0)
    where
      lifted = [d|
       instance Lifted $_t where
        (==!)         = (==) `on` primal
        compare1      = compare `on` primal
        maxBound1     = lift maxBound
        minBound1     = lift minBound
        showsPrec1    = showsPrec
        fromInteger1  = lift . fromInteger
        (+!)          = (<+>)
        -- x - y: factor 'one' for x and 'negOne' for y.
        (-!)          = binary (-) one negOne
        (*!)          = lift2 (*) (\x y -> (y, x))
        negate1       = lift1 negate (const negOne)
        abs1          = lift1 abs signum1
        signum1       = lift1 signum (const zero)
        fromRational1 = lift . fromRational
        -- With a = x / y: d/dx = 1/y and d/dy = -x/y^2 = -a * (1/y).
        (/!)          = lift2_ (/) $ \a _ y -> let ry = recip1 y in (ry, negate1 a *! ry)
        recip1        = lift1_ recip (const . negate1 . square1)
        pi1       = lift pi
        exp1      = lift1_ exp const
        log1      = lift1 log recip1
        logBase1 x y = log1 y /! log1 x
        sqrt1     = lift1_ sqrt (\z _ -> recip1 (lift 2 *! z))
        (**!)     = lift2_ (**) (\z x y -> (y *! z /! x, z *! log1 x))
        sin1      = lift1 sin cos1
        cos1      = lift1 cos $ \x -> negate1 (sin1 x)
        tan1 x    = sin1 x /! cos1 x
        asin1     = lift1 asin $ \x -> recip1 (sqrt1 (one -! square1 x))
        acos1     = lift1 acos $ \x -> negate1 (recip1 (sqrt1 (one -! square1 x)))
        atan1     = lift1 atan $ \x -> recip1 (one +! square1 x)
        sinh1     = lift1 sinh cosh1
        cosh1     = lift1 cosh sinh1
        tanh1 x   = sinh1 x /! cosh1 x
        asinh1    = lift1 asinh $ \x -> recip1 (sqrt1 (one +! square1 x))
        acosh1    = lift1 acosh $ \x -> recip1 (sqrt1 (square1 x -! one))
        atanh1    = lift1 atanh $ \x -> recip1 (one -! square1 x)
        succ1                 = lift1 succ (const one)
        pred1                 = lift1 pred (const one)
        toEnum1               = lift . toEnum
        fromEnum1             = discrete1 fromEnum
        enumFrom1 a           = withPrimal a <$> discrete1 enumFrom a
        enumFromTo1 a b       = withPrimal a <$> discrete2 enumFromTo a b
        enumFromThen1 a b     = zipWith (fromBy a delta) [0..] $ discrete2 enumFromThen a b where delta = b -! a
        enumFromThenTo1 a b c = zipWith (fromBy a delta) [0..] $ discrete3 enumFromThenTo a b c where delta = b -! a
        toRational1      = discrete1 toRational
        floatRadix1      = discrete1 floatRadix
        floatDigits1     = discrete1 floatDigits
        floatRange1      = discrete1 floatRange
        decodeFloat1     = discrete1 decodeFloat
        encodeFloat1 m e = lift (encodeFloat m e)
        isNaN1           = discrete1 isNaN
        isInfinite1      = discrete1 isInfinite
        isDenormalized1  = discrete1 isDenormalized
        isNegativeZero1  = discrete1 isNegativeZero
        isIEEE1          = discrete1 isIEEE
        exponent1 = exponent . primal
        scaleFloat1 n = unary (scaleFloat n) (scaleFloat1 n one)
        -- significand x == scaleFloat (negate (exponent x)) x, so the factor
        -- follows the same rule as scaleFloat1 with that negated exponent.
        significand1 x = unary significand (scaleFloat1 (negate (exponent1 x)) one) x
        atan21 = lift2 atan2 $ \vx vy -> let r = recip1 (square1 vx +! square1 vy) in (vy *! r, negate1 vx *! r)
        properFraction1 a = (w, a `withPrimal` pb) where
             pa = primal a
             (w, pb) = properFraction pa
        truncate1 = discrete1 truncate
        round1    = discrete1 round
        ceiling1  = discrete1 ceiling
        floor1    = discrete1 floor |]
-- | The type variable @a@, used as the element type in generated instances.
varA :: Q Type
varA = varT elementVar
  where
    elementVar = mkName "a"
-- | Base names of every method signature declared in the 'Lifted' class,
-- obtained by reifying the class.
liftedMembers :: Q [String]
liftedMembers = do
    ClassI (ClassD _ _ _ _ decls) <- reify ''Lifted
    return (concatMap sigBase decls)
  where
    -- Only signature declarations contribute a name.
    sigBase (SigD name _) = [nameBase name]
    sigBase _             = []
-- | Generate instances of the standard numeric classes for @t a@ by
-- forwarding each class method to its 'Lifted' counterpart.
--
-- @f@ transforms the context of every generated instance. For 'Enum', 'Eq',
-- 'Ord' and 'Bounded' a @Num a@ constraint is additionally prepended to the
-- context produced by @f@.
deriveNumeric :: ([Q Pred] -> [Q Pred]) -> Q Type -> Q [Dec]
deriveNumeric f t = do
    members <- liftedMembers
    let keep name = nameBase name `elem` members
        withNum   = (classP ''Num [varA] :) . f
        lowerAll ctx = mapM (lowerInstance keep ctx t)
    numConstrained <- lowerAll withNum [''Enum, ''Eq, ''Ord, ''Bounded]
    plain          <- lowerAll f [''Show, ''Num, ''Fractional, ''Floating, ''RealFloat, ''RealFrac, ''Real]
    return (numConstrained ++ plain)
-- | Build an instance of class @n@ for @t a@ in which each class method is
-- defined as the correspondingly named lifted member.
--
-- * @p@ decides whether a candidate lifted name may be used; methods whose
--   lifted name is rejected are left out of the instance.
-- * @f@ transforms the instance context, which starts as @n a@.
--
-- The lifted name is the method's base name plus a one-character suffix:
-- @!@ when the name starts with an operator character, @1@ otherwise.
lowerInstance :: (Name -> Bool) -> ([Q Pred] -> [Q Pred]) -> Q Type -> Name -> Q Dec
lowerInstance p f t n = do
    ClassI (ClassD _ _ _ _ ds) <- reify n
    instanceD (cxt (f [classP n [varA]]))
              (conT n `appT` (t `appT` varA))
              (concatMap lower1 ds)
    where
        lower1 :: Dec -> [Q Dec]
        lower1 (SigD n' _) | p n'' = [valD (varP n') (normalB (varE n'')) []] where n'' = primed n'
        lower1 _          = []
        primed n' = mkName (base ++ [suffix base])
            where
                base = nameBase n'
        -- Chosen by pattern matching on the first character rather than
        -- with 'head', so the helper is total. The explicit character list
        -- covers operator characters that 'isSymbol' does not classify as
        -- symbols.
        suffix :: String -> Char
        suffix (h:_) | isSymbol h || h `elem` "/*-<>" = '!'
        suffix _ = '1'