{-# LANGUAGE DeriveDataTypeable    #-}
{-# LANGUAGE DeriveFunctor         #-}
{-# LANGUAGE DeriveGeneric         #-}
{-# LANGUAGE FlexibleInstances     #-}
{-# LANGUAGE MultiParamTypeClasses #-}

-- |
-- Module      : Numeric.Backprop.Tuple
-- Copyright   : (c) Justin Le 2018
-- License     : BSD3
--
-- Maintainer  : justin@jle.im
-- Stability   : experimental
-- Portability : non-portable
--
-- Canonical strict tuples with 'Num' instances for usage with /backprop/.
-- This is here to solve the problem of orphan instances in libraries and
-- potential mismatched tuple types.
--
-- If you are writing a library that needs to export 'BVar's of tuples,
-- consider using the tuples in this module so that your library can have
-- easy interoperability with other libraries using /backprop/.
--
-- Because of API decisions, 'backprop' and 'gradBP' only work with things
-- with 'Num' instances.  However, this disallows default 'Prelude' tuples
-- (without orphan instances from packages like
-- <https://hackage.haskell.org/package/NumInstances NumInstances>).
--
-- Until tuples have 'Num' instances in /base/, this module is intended to
-- be a workaround.  This comes up often in cases where:
--
--     (1) A function wants to return more than one value
--     (@'BVar' s ('T2' a b)@).
--     (2) You want to uncurry a 'BVar' function to use with 'backprop' and
--     'gradBP'.
--     (3) You want to use the useful 'Prism's automatically generated by
--     the lens library, which use tuples for constructors with multiple
--     fields.
--
-- Only 2-tuples and 3-tuples are provided.  Any more and you should
-- probably be using your own custom product types, with instances
-- automatically generated from something like
-- <https://hackage.haskell.org/package/one-liner-instances one-liner-instances>.
--
-- Lenses into the fields are provided, and the types also work with '_1',
-- '_2', and '_3' from "Lens.Micro".  However, those instances are for the
-- microlens classes only, so '_1', '_2', and '_3' from "Control.Lens" do
-- not work with these types.
--
-- @since 0.1.1.0
--


module Numeric.Backprop.Tuple (
  -- * Two-tuples
    T2(..)
  -- ** Conversions
  -- $t2iso
  , t2Tup, tupT2
  -- ** Consumption
  , uncurryT2, curryT2
  -- ** Lenses
  , t2_1, t2_2
  -- * Three-tuples
  , T3(..)
  -- ** Conversions
  -- $t3iso
  , t3Tup, tupT3
  -- ** Consumption
  , uncurryT3, curryT3
  -- ** Lenses
  , t3_1, t3_2, t3_3
  ) where

import           Control.DeepSeq
import           Data.Bifunctor
import           Data.Data
import           Data.Semigroup
import           GHC.Generics        (Generic)
import           Lens.Micro
import           Lens.Micro.Internal

-- | A strict pair.  Both fields are strict, so evaluating the constructor
-- also evaluates each component to weak head normal form.  Unlike the
-- Prelude pair, it carries component-wise 'Num' (and related) instances.
--
-- @since 0.1.1.0
data T2 a b = T2 !a !b
  deriving ( Show, Read, Eq, Ord
           , Generic, Functor, Data
           )

-- | A strict triple.  All three fields are strict, so evaluating the
-- constructor also evaluates each component to weak head normal form.
-- Unlike the Prelude triple, it carries component-wise 'Num' (and related)
-- instances.
--
-- @since 0.1.1.0
data T3 a b c = T3 !a !b !c
  deriving ( Show, Read, Eq, Ord
           , Generic, Functor, Data
           )

-- | No 'rnf' method is written out, so the class default is used.  With
-- deepseq >= 1.4 that default is 'Generic'-based (the derived 'Generic'
-- instance makes it available) and forces every component to normal form.
-- NOTE(review): older deepseq defaults only evaluate to WHNF; confirm the
-- package's deepseq lower bound.
instance (NFData a, NFData b) => NFData (T2 a b)
-- | Same as the 'T2' instance: relies on the default (presumably
-- 'Generic'-based) 'rnf' for all three components.
instance (NFData a, NFData b, NFData c) => NFData (T3 a b c)

-- | Maps the first function over the first component and the second
-- function over the second component.
instance Bifunctor T2 where
    bimap f g t = case t of
      T2 x y -> T2 (f x) (g y)

-- | Maps over the second and third components; the first component is
-- passed through untouched.
instance Bifunctor (T3 a) where
    bimap f g t = case t of
      T3 x y z -> T3 x (f y) (g z)

-- | Turn a 'T2' into the equivalent Prelude pair.
--
-- Inverse of 'tupT2', up to strictness: a 'T2' always holds evaluated
-- components, while a Prelude pair need not.
t2Tup :: T2 a b -> (a, b)
t2Tup t = case t of
    T2 x y -> (x, y)

-- | Turn a Prelude pair into a 'T2', forcing both components.
--
-- Inverse of 't2Tup', up to strictness.
tupT2 :: (a, b) -> T2 a b
tupT2 = uncurry T2

-- | Turn a 'T3' into the equivalent Prelude triple.
--
-- Inverse of 'tupT3', up to strictness: a 'T3' always holds evaluated
-- components, while a Prelude triple need not.
t3Tup :: T3 a b c -> (a, b, c)
t3Tup t = case t of
    T3 x y z -> (x, y, z)

-- | Turn a Prelude triple into a 'T3', forcing all three components.
--
-- Inverse of 't3Tup', up to strictness.
tupT3 :: (a, b, c) -> T3 a b c
tupT3 triple = case triple of
    (x, y, z) -> T3 x y z

-- | Adapt a two-argument function so that it takes both arguments packed
-- in a single 'T2'.
--
-- @since 0.1.2.0
uncurryT2 :: (a -> b -> c) -> T2 a b -> c
uncurryT2 f t = case t of
    T2 x y -> f x y

-- | Adapt a function on a 'T2' so that it takes the two components as
-- separate arguments.
--
-- @since 0.1.2.0
curryT2 :: (T2 a b -> c) -> a -> b -> c
curryT2 f x = f . T2 x

-- | Adapt a three-argument function so that it takes all arguments packed
-- in a single 'T3'.
--
-- @since 0.1.2.0
uncurryT3 :: (a -> b -> c -> d) -> T3 a b c -> d
uncurryT3 f t = case t of
    T3 x y z -> f x y z

-- | Adapt a function on a 'T3' so that it takes the three components as
-- separate arguments.
--
-- @since 0.1.2.0
curryT3 :: (T3 a b c -> d) -> a -> b -> c -> d
curryT3 f x y = f . T3 x y

-- | Focus on the first component; the second is carried through unchanged.
instance Field1 (T2 a b) (T2 a' b) a a' where
    _1 f t = case t of
      T2 x y -> fmap (\x' -> T2 x' y) (f x)

-- | Focus on the second component; the first is carried through unchanged.
instance Field2 (T2 a b) (T2 a b') b b' where
    _2 f t = case t of
      T2 x y -> fmap (T2 x) (f y)

-- | Focus on the first component; the other two are carried through.
instance Field1 (T3 a b c) (T3 a' b c) a a' where
    _1 f (T3 x y z) = fmap rebuild (f x)
      where
        rebuild x' = T3 x' y z

-- | Focus on the second component; the other two are carried through.
instance Field2 (T3 a b c) (T3 a b' c) b b' where
    _2 f (T3 x y z) = fmap rebuild (f y)
      where
        rebuild y' = T3 x y' z

-- | Focus on the third component; the other two are carried through.
instance Field3 (T3 a b c) (T3 a b c') c c' where
    _3 f t = case t of
      T3 x y z -> fmap (T3 x y) (f z)

-- | Lens onto the first component of a 'T2', allowing its type to change.
-- Behaves exactly like the 'Field1' instance, so '_1' from "Lens.Micro"
-- can be used interchangeably.
t2_1 :: Lens (T2 a b) (T2 a' b) a a'
t2_1 f (T2 x y) = (\x' -> T2 x' y) <$> f x

-- | Lens onto the second component of a 'T2', allowing its type to change.
-- Behaves exactly like the 'Field2' instance, so '_2' from "Lens.Micro"
-- can be used interchangeably.
t2_2 :: Lens (T2 a b) (T2 a b') b b'
t2_2 f (T2 x y) = T2 x <$> f y

-- | Lens onto the first component of a 'T3', allowing its type to change.
-- Behaves exactly like the 'Field1' instance, so '_1' from "Lens.Micro"
-- can be used interchangeably.
t3_1 :: Lens (T3 a b c) (T3 a' b c) a a'
t3_1 f (T3 x y z) = (\x' -> T3 x' y z) <$> f x

-- | Lens onto the second component of a 'T3', allowing its type to change.
-- Behaves exactly like the 'Field2' instance, so '_2' from "Lens.Micro"
-- can be used interchangeably.
t3_2 :: Lens (T3 a b c) (T3 a b' c) b b'
t3_2 f (T3 x y z) = (\y' -> T3 x y' z) <$> f y

-- | Lens onto the third component of a 'T3', allowing its type to change.
-- Behaves exactly like the 'Field3' instance, so '_3' from "Lens.Micro"
-- can be used interchangeably.
t3_3 :: Lens (T3 a b c) (T3 a b c') c c'
t3_3 f (T3 x y z) = T3 x y <$> f z

-- | Component-wise arithmetic.  Binary operators combine matching
-- components, unary ones are applied to each component, and 'fromInteger'
-- broadcasts the literal to both components.
instance (Num a, Num b) => Num (T2 a b) where
    (+) (T2 a1 b1) (T2 a2 b2) = T2 (a1 + a2) (b1 + b2)
    (-) (T2 a1 b1) (T2 a2 b2) = T2 (a1 - a2) (b1 - b2)
    (*) (T2 a1 b1) (T2 a2 b2) = T2 (a1 * a2) (b1 * b2)
    negate        = bimap negate negate
    abs           = bimap abs    abs
    signum        = bimap signum signum
    fromInteger n = T2 (fromInteger n) (fromInteger n)

-- | Component-wise division; 'fromRational' broadcasts the literal to both
-- components.
instance (Fractional a, Fractional b) => Fractional (T2 a b) where
    (/) (T2 a1 b1) (T2 a2 b2) = T2 (a1 / a2) (b1 / b2)
    recip          = bimap recip recip
    fromRational r = T2 (fromRational r) (fromRational r)

-- | Component-wise 'Floating' operations; 'pi' is broadcast to both
-- components.
--
-- 'tan' and 'tanh' are forwarded to the components instead of being left
-- to the class defaults (@sin x \/ cos x@ and @sinh x \/ cosh x@).  For
-- 'Double', the 'tanh' default yields NaN once @sinh@ and @cosh@ overflow
-- (e.g. @tanh 1000@ would become @Infinity \/ Infinity@), whereas the
-- component's own 'tanh' returns @1.0@.
instance (Floating a, Floating b) => Floating (T2 a b) where
    pi                            = T2 pi pi
    T2 x1 y1 ** T2 x2 y2          = T2 (x1 ** x2) (y1 ** y2)
    logBase (T2 x1 y1) (T2 x2 y2) = T2 (logBase x1 x2) (logBase y1 y2)
    exp   (T2 x y)                = T2 (exp   x) (exp   y)
    log   (T2 x y)                = T2 (log   x) (log   y)
    sqrt  (T2 x y)                = T2 (sqrt  x) (sqrt  y)
    sin   (T2 x y)                = T2 (sin   x) (sin   y)
    cos   (T2 x y)                = T2 (cos   x) (cos   y)
    tan   (T2 x y)                = T2 (tan   x) (tan   y)
    asin  (T2 x y)                = T2 (asin  x) (asin  y)
    acos  (T2 x y)                = T2 (acos  x) (acos  y)
    atan  (T2 x y)                = T2 (atan  x) (atan  y)
    sinh  (T2 x y)                = T2 (sinh  x) (sinh  y)
    cosh  (T2 x y)                = T2 (cosh  x) (cosh  y)
    tanh  (T2 x y)                = T2 (tanh  x) (tanh  y)
    asinh (T2 x y)                = T2 (asinh x) (asinh y)
    acosh (T2 x y)                = T2 (acosh x) (acosh y)
    atanh (T2 x y)                = T2 (atanh x) (atanh y)

-- | Combines matching components with their own '<>'.
instance (Semigroup a, Semigroup b) => Semigroup (T2 a b) where
    (<>) (T2 a1 b1) (T2 a2 b2) = T2 (a1 <> a2) (b1 <> b2)

-- | Component-wise monoid: the identity pairs the components' identities,
-- and 'mappend' combines matching components.
instance (Monoid a, Monoid b) => Monoid (T2 a b) where
    -- 'mappend' is spelled out per component rather than as '(<>)';
    -- presumably so the instance does not need Semigroup as a superclass of
    -- Monoid (base < 4.11).  NOTE(review): confirm the supported base range.
    mempty = T2 mempty mempty
    mappend l r = case (l, r) of
      (T2 a1 b1, T2 a2 b2) -> T2 (a1 `mappend` a2) (b1 `mappend` b2)

-- | Component-wise arithmetic.  Binary operators combine matching
-- components with the same operator (e.g.
-- @T3 5 5 5 - T3 1 2 3 == T3 4 3 2@), unary ones are applied to each
-- component, and 'fromInteger' broadcasts the literal to all three
-- components.
instance (Num a, Num b, Num c) => Num (T3 a b c) where
    T3 x1 y1 z1 + T3 x2 y2 z2 = T3 (x1 + x2) (y1 + y2) (z1 + z2)
    -- Third component previously used (+) in both (-) and (*).
    T3 x1 y1 z1 - T3 x2 y2 z2 = T3 (x1 - x2) (y1 - y2) (z1 - z2)
    T3 x1 y1 z1 * T3 x2 y2 z2 = T3 (x1 * x2) (y1 * y2) (z1 * z2)
    negate (T3 x y z)         = T3 (negate x) (negate y) (negate z)
    abs    (T3 x y z)         = T3 (abs    x) (abs    y) (abs    z)
    signum (T3 x y z)         = T3 (signum x) (signum y) (signum z)
    fromInteger x             = T3 (fromInteger x) (fromInteger x) (fromInteger x)

-- | Component-wise division; 'fromRational' broadcasts the literal to all
-- three components.
instance (Fractional a, Fractional b, Fractional c) => Fractional (T3 a b c) where
    (/) (T3 a1 b1 c1) (T3 a2 b2 c2) = T3 (a1 / a2) (b1 / b2) (c1 / c2)
    recip t = case t of
      T3 a b c -> T3 (recip a) (recip b) (recip c)
    fromRational r = T3 (fromRational r) (fromRational r) (fromRational r)

-- | Component-wise 'Floating' operations; 'pi' is broadcast to all three
-- components.
--
-- 'tan' and 'tanh' are forwarded to the components instead of being left
-- to the class defaults (@sin x \/ cos x@ and @sinh x \/ cosh x@).  For
-- 'Double', the 'tanh' default yields NaN once @sinh@ and @cosh@ overflow
-- (e.g. @tanh 1000@ would become @Infinity \/ Infinity@), whereas the
-- component's own 'tanh' returns @1.0@.
instance (Floating a, Floating b, Floating c) => Floating (T3 a b c) where
    pi                                  = T3 pi pi pi
    T3 x1 y1 z1 ** T3 x2 y2 z2          = T3 (x1 ** x2) (y1 ** y2) (z1 ** z2)
    logBase (T3 x1 y1 z1) (T3 x2 y2 z2) = T3 (logBase x1 x2) (logBase y1 y2) (logBase z1 z2)
    exp   (T3 x y z)                    = T3 (exp   x) (exp   y) (exp   z)
    log   (T3 x y z)                    = T3 (log   x) (log   y) (log   z)
    sqrt  (T3 x y z)                    = T3 (sqrt  x) (sqrt  y) (sqrt  z)
    sin   (T3 x y z)                    = T3 (sin   x) (sin   y) (sin   z)
    cos   (T3 x y z)                    = T3 (cos   x) (cos   y) (cos   z)
    tan   (T3 x y z)                    = T3 (tan   x) (tan   y) (tan   z)
    asin  (T3 x y z)                    = T3 (asin  x) (asin  y) (asin  z)
    acos  (T3 x y z)                    = T3 (acos  x) (acos  y) (acos  z)
    atan  (T3 x y z)                    = T3 (atan  x) (atan  y) (atan  z)
    sinh  (T3 x y z)                    = T3 (sinh  x) (sinh  y) (sinh  z)
    cosh  (T3 x y z)                    = T3 (cosh  x) (cosh  y) (cosh  z)
    tanh  (T3 x y z)                    = T3 (tanh  x) (tanh  y) (tanh  z)
    asinh (T3 x y z)                    = T3 (asinh x) (asinh y) (asinh z)
    acosh (T3 x y z)                    = T3 (acosh x) (acosh y) (acosh z)
    atanh (T3 x y z)                    = T3 (atanh x) (atanh y) (atanh z)

-- | Combines matching components with their own '<>'.
instance (Semigroup a, Semigroup b, Semigroup c) => Semigroup (T3 a b c) where
    (<>) (T3 a1 b1 c1) (T3 a2 b2 c2) = T3 (a1 <> a2) (b1 <> b2) (c1 <> c2)

-- | Component-wise monoid: the identity is built from the components'
-- identities, and 'mappend' combines matching components.
instance (Monoid a, Monoid b, Monoid c) => Monoid (T3 a b c) where
    -- 'mappend' is spelled out per component rather than as '(<>)';
    -- presumably so the instance does not need Semigroup as a superclass of
    -- Monoid (base < 4.11).  NOTE(review): confirm the supported base range.
    mempty = T3 mempty mempty mempty
    mappend l r = case (l, r) of
      (T3 a1 b1 c1, T3 a2 b2 c2) ->
        T3 (a1 `mappend` a2) (b1 `mappend` b2) (c1 `mappend` c2)

-- $t2iso
--
-- If using /lens/, the two conversion functions can be chained with prisms
-- and traversals and other optics using:
--
-- @
-- 'iso' 'tupT2' 't2Tup' :: 'Iso'' (a, b) ('T2' a b)
-- @

-- $t3iso
--
-- If using /lens/, the two conversion functions can be chained with prisms
-- and traversals and other optics using:
--
-- @
-- 'iso' 'tupT3' 't3Tup' :: 'Iso'' (a, b, c) ('T3' a b c)
-- @