{-# LANGUAGE AllowAmbiguousTypes   #-}
{-# LANGUAGE BangPatterns          #-}
{-# LANGUAGE CPP                   #-}
{-# LANGUAGE ConstraintKinds       #-}
{-# LANGUAGE DeriveDataTypeable    #-}
{-# LANGUAGE DeriveFunctor         #-}
{-# LANGUAGE DeriveGeneric         #-}
{-# LANGUAGE FlexibleInstances     #-}
{-# LANGUAGE GADTs                 #-}
{-# LANGUAGE KindSignatures        #-}
{-# LANGUAGE LambdaCase            #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE RankNTypes            #-}
{-# LANGUAGE ScopedTypeVariables   #-}
{-# LANGUAGE TypeApplications      #-}
{-# LANGUAGE TypeInType            #-}
{-# LANGUAGE TypeOperators         #-}
{-# LANGUAGE UndecidableInstances  #-}
{-# LANGUAGE ViewPatterns          #-}

-- |
-- Module      : Numeric.Backprop.Tuple
-- Copyright   : (c) Justin Le 2018
-- License     : BSD3
--
-- Maintainer  : justin@jle.im
-- Stability   : experimental
-- Portability : non-portable
--
-- Canonical strict tuples (and unit) with 'Num' instances for usage with
-- /backprop/. This is here to solve the problem of orphan instances in
-- libraries and potential mismatched tuple types.
--
-- If you are writing a library that needs to export 'BVar's of tuples,
-- consider using the tuples in this module so that your library can have
-- easy interoperability with other libraries using /backprop/.
--
-- Because of API decisions, 'backprop' and 'gradBP' only work with things
-- with 'Num' instances.  However, this disallows default 'Prelude' tuples
-- (without orphan instances from packages like
-- <https://hackage.haskell.org/package/NumInstances NumInstances>).
--
-- Until tuples have 'Num' instances in /base/, this module is intended to
-- be a workaround for situations where:
--
-- This comes up often in cases where:
--
--     (1) A function wants to return more than one value (@'BVar' s ('T2'
--     a b)@
--     (2) You want to uncurry a 'BVar' function to use with 'backprop' and
--     'gradBP'.
--     (3) You want to use the useful 'Prism's automatically generated by
--     the lens library, which use tuples for multiple-constructor fields.
--
-- Only 2-tuples and 3-tuples are provided.  Any more and you should
-- probably be using your own custom product types, with instances
-- automatically generated from something like
-- <https://hackage.haskell.org/package/one-liner-instances one-liner-instances>.
--
-- Lenses into the fields are provided, but they also work with '_1', '_2',
-- and '_3' from "Lens.Micro".  However, note that these are incompatible
-- with '_1', '_2', and '_3' from "Control.Lens".
--
-- You can "construct" a @'BVar' s ('T2' a b)@ with functions like
-- 'isoVar'.
--
-- @since 0.1.1.0
--


module Numeric.Backprop.Tuple (
  -- * Zero-tuples (unit)
    T0(..)
  -- * Two-tuples
  , T2(..)
  -- ** Conversions
  -- $t2iso
  , t2Tup, tupT2
  -- ** Consumption
  , uncurryT2, curryT2
  -- ** Lenses
  , t2_1, t2_2
  -- * Three-tuples
  , T3(..)
  -- ** Conversions
  -- $t3iso
  , t3Tup, tupT3
  -- ** Lenses
  , t3_1, t3_2, t3_3
  -- ** Consumption
  , uncurryT3, curryT3
  -- * N-Tuples
  , T(..)
  , indexT
  -- ** Conversions
  -- $tiso
  , tOnly, onlyT, tSplit, tAppend, tProd, prodT
  -- ** Lenses
  , tIx, tHead, tTail, tTake, tDrop
  -- ** Internal Utility
  , constT, mapT, zipT
  ) where

import           Control.DeepSeq
import           Data.Bifunctor
import           Data.Data
import           Data.Kind
import           Data.Type.Combinator
import           Data.Type.Index
import           Data.Type.Length
import           Data.Type.Product
import           GHC.Generics               (Generic)
import           Lens.Micro
import           Lens.Micro.Internal hiding (Index)
import           Type.Class.Known
import           Type.Family.List

#if !MIN_VERSION_base(4,11,0)
import           Data.Semigroup
#endif

-- | Unit ('()') with 'Num', 'Fractional', and 'Floating' instances.
--
-- Be aware that the methods in its numerical instances are all non-strict:
--
-- @@
-- _ + _ = 'T0'
-- 'negate' _ = 'T0'
-- 'fromIntegral' _ = 'T0'
-- @@
--
-- @since 0.1.4.0
data T0 = T0
  deriving (Show, Read, Eq, Ord, Generic, Data)

instance NFData T0

-- | Strict 2-tuple with 'Num', 'Fractional', and 'Floating' instances.
--
-- @since 0.1.1.0
data T2 a b   = T2 !a !b
  deriving (Show, Read, Eq, Ord, Generic, Functor, Data)

-- | Strict 3-tuple with a 'Num', 'Fractional', and 'Floating' instances.
--
-- @since 0.1.1.0
data T3 a b c = T3 !a !b !c
  deriving (Show, Read, Eq, Ord, Generic, Functor, Data)

-- | Strict inductive N-tuple with a 'Num', 'Fractional', and 'Floating'
-- instances.
--
-- It is basically "yet another HList", like the one found in
-- "Data.Type.Product" and many other locations on the haskell ecosystem.
-- Because it's inductively defined, it has O(n) random indexing, but is
-- efficient for zipping and mapping and other sequential consumption
-- patterns.
--
-- It is provided because of its 'Num' instance, making it useful for
-- /backproup/.  Will be obsolete when 'Data.Type.Product.Product' gets
-- numerical instances.
--
-- @since 0.1.5.0
data T :: [Type] -> Type where
    TNil :: T '[]
    (:&) :: !a -> !(T as) -> T (a ': as)

instance (NFData a, NFData b) => NFData (T2 a b)
instance (NFData a, NFData b, NFData c) => NFData (T3 a b c)
instance ListC (NFData <$> as) => NFData (T as) where
    rnf = \case
      TNil       -> ()
      (!_) :& xs -> rnf xs

instance Bifunctor T2 where
    bimap f g (T2 x y) = T2 (f x) (g y)

instance Bifunctor (T3 a) where
    bimap f g (T3 x y z) = T3 x (f y) (g z)

-- | Convert to a Haskell tuple.
--
-- Forms an isomorphism with 'tupT2'.
t2Tup :: T2 a b -> (a, b)
t2Tup (T2 x y) = (x, y)

-- | Convert from Haskell tuple.
--
-- Forms an isomorphism with 't2Tup'.
tupT2 :: (a, b) -> T2 a b
tupT2 (x, y) = T2 x y

-- | Convert to a Haskell tuple.
--
-- Forms an isomorphism with 'tupT3'.
t3Tup :: T3 a b c -> (a, b, c)
t3Tup (T3 x y z) = (x, y, z)

-- | Convert from Haskell tuple.
--
-- Forms an isomorphism with 't3Tup'.
tupT3 :: (a, b, c) -> T3 a b c
tupT3 (x, y, z) = T3 x y z

-- | A singleton 'T'
--
-- Forms an isomorphism with 'tOnly'
--
-- @since 0.1.5.0
onlyT :: a -> T '[a]
onlyT = (:& TNil)

-- | Extract a singleton 'T'
--
-- Forms an isomorphism with 'onlyT'
--
-- @since 0.1.5.0
tOnly :: T '[a] -> a
tOnly (x :& _) = x

-- | Uncurry a function to take in a 'T2' of its arguments
--
-- @since 0.1.2.0
uncurryT2 :: (a -> b -> c) -> T2 a b -> c
uncurryT2 f (T2 x y) = f x y

-- | Curry a function taking a 'T2' of its arguments
--
-- @since 0.1.2.0
curryT2 :: (T2 a b -> c) -> a -> b -> c
curryT2 f x y = f (T2 x y)

-- | Uncurry a function to take in a 'T3' of its arguments
--
-- @since 0.1.2.0
uncurryT3 :: (a -> b -> c -> d) -> T3 a b c -> d
uncurryT3 f (T3 x y z) = f x y z

-- | Curry a function taking a 'T3' of its arguments
--
-- @since 0.1.2.0
curryT3 :: (T3 a b c -> d) -> a -> b -> c -> d
curryT3 f x y z = f (T3 x y z)

instance Field1 (T2 a b) (T2 a' b) a a' where
    _1 = t2_1

instance Field2 (T2 a b) (T2 a b') b b' where
    _2 = t2_2

instance Field1 (T3 a b c) (T3 a' b c) a a' where
    _1 = t3_1

instance Field2 (T3 a b c) (T3 a b' c) b b' where
    _2 = t3_2

instance Field3 (T3 a b c) (T3 a b c') c c' where
    _3 = t3_3

instance Field1 (T (a ': as)) (T (a ': as)) a a where
    _1 = tIx IZ

instance Field2 (T (a ': b ': as)) (T (a ': b ': as)) b b where
    _2 = tIx (IS IZ)

instance Field3 (T (a ': b ': c ': as)) (T (a ': b ': c ': as)) c c where
    _3 = tIx (IS (IS IZ))

-- | Lens into the first field of a 'T2'.  Also exported as '_1' from
-- "Lens.Micro".
t2_1 :: Lens (T2 a b) (T2 a' b) a a'
t2_1 f (T2 x y) = (`T2` y) <$> f x

-- | Lens into the second field of a 'T2'.  Also exported as '_2' from
-- "Lens.Micro".
t2_2 :: Lens (T2 a b) (T2 a b') b b'
t2_2 f (T2 x y) = T2 x <$> f y

-- | Lens into the first field of a 'T3'.  Also exported as '_1' from
-- "Lens.Micro".
t3_1 :: Lens (T3 a b c) (T3 a' b c) a a'
t3_1 f (T3 x y z) = (\x' -> T3 x' y z) <$> f x

-- | Lens into the second field of a 'T3'.  Also exported as '_2' from
-- "Lens.Micro".
t3_2 :: Lens (T3 a b c) (T3 a b' c) b b'
t3_2 f (T3 x y z) = (\y' -> T3 x y' z) <$> f y

-- | Lens into the third field of a 'T3'.  Also exported as '_3' from
-- "Lens.Micro".
t3_3 :: Lens (T3 a b c) (T3 a b c') c c'
t3_3 f (T3 x y z) = T3 x y <$> f z

-- | Index into a 'T'.
--
-- /O(i)/
--
-- @since 0.1.5.0
indexT :: Index as a -> T as -> a
indexT = flip (^.) . tIx

-- | Lens into a given index of a 'T'.
--
-- @since 0.1.5.0
tIx :: Index as a -> Lens' (T as) a
tIx IZ     f (x :& xs) = (:& xs) <$> f x
tIx (IS i) f (x :& xs) = (x :&)  <$> tIx i f xs

-- | Lens into the head of a 'T'
--
-- @since 0.1.5.0
tHead :: Lens (T (a ': as)) (T (b ': as)) a b
tHead f (x :& xs) = (:& xs) <$> f x

-- | Lens into the tail of a 'T'
--
-- @since 0.1.5.0
tTail :: Lens (T (a ': as)) (T (a ': bs)) (T as) (T bs)
tTail f (x :& xs) = (x :&) <$> f xs

-- | Append two 'T's.
--
-- Forms an isomorphism with 'tSplit'.
--
-- @since 0.1.5.0
tAppend :: T as -> T bs -> T (as ++ bs)
tAppend TNil      ys = ys
tAppend (x :& xs) ys = x :& tAppend xs ys
infixr 5 `tAppend`

-- | Split a 'T'.  For splits known at compile-time, you can use 'known' to
-- derive the 'Length' automatically.
--
-- Forms an isomorphism with 'tAppend'.
--
-- @since 0.1.5.0
tSplit :: Length as -> T (as ++ bs) -> (T as, T bs)
tSplit LZ     xs        = (TNil, xs)
tSplit (LS l) (x :& xs) = first (x :&) . tSplit l $ xs

-- | Lens into the initial portion of a 'T'.  For splits known at
-- compile-time, you can use 'known' to derive the 'Length' automatically.
--
-- @since 0.1.5.0
tTake :: forall as bs cs. Length as -> Lens (T (as ++ bs)) (T (cs ++ bs)) (T as) (T cs)
tTake l f (tSplit l->(xs,ys)) = flip (tAppend @cs @bs) ys <$> f xs

-- | Lens into the ending portion of a 'T'.  For splits known at
-- compile-time, you can use 'known' to derive the 'Length' automatically.
--
-- @since 0.1.5.0
tDrop :: forall as bs cs. Length as -> Lens (T (as ++ bs)) (T (as ++ cs)) (T bs) (T cs)
tDrop l f (tSplit l->(xs,ys)) = tAppend xs <$> f ys

-- | Convert a 'T' to a 'Tuple'.
--
-- Forms an isomorphism with 'prodT'.
--
-- @since 0.1.5.0
tProd :: T as -> Tuple as
tProd TNil      = Ø
tProd (x :& xs) = x ::< tProd xs

-- | Convert a 'Tuple' to a 'T'.
--
-- Forms an isomorphism with 'tProd'.
--
-- @since 0.1.5.0
prodT :: Tuple as -> T as
prodT Ø           = TNil
prodT (I x :< xs) = x :& prodT xs


instance Num T0 where
    _ + _         = T0
    _ - _         = T0
    _ * _         = T0
    negate _      = T0
    abs    _      = T0
    signum _      = T0
    fromInteger _ = T0

instance Fractional T0 where
    _ / _          = T0
    recip _        = T0
    fromRational _ = T0

instance Floating T0 where
    pi          = T0
    _ ** _      = T0
    logBase _ _ = T0
    exp   _     = T0
    log   _     = T0
    sqrt  _     = T0
    sin   _     = T0
    cos   _     = T0
    asin  _     = T0
    acos  _     = T0
    atan  _     = T0
    sinh  _     = T0
    cosh  _     = T0
    asinh _     = T0
    acosh _     = T0
    atanh _     = T0

instance Semigroup T0 where
    _ <> _ = T0

instance Monoid T0 where
    mempty = T0
    mappend = (<>)

instance (Num a, Num b) => Num (T2 a b) where
    T2 x1 y1 + T2 x2 y2 = T2 (x1 + x2) (y1 + y2)
    T2 x1 y1 - T2 x2 y2 = T2 (x1 - x2) (y1 - y2)
    T2 x1 y1 * T2 x2 y2 = T2 (x1 * x2) (y1 * y2)
    negate (T2 x y)     = T2 (negate x) (negate y)
    abs    (T2 x y)     = T2 (abs    x) (abs    y)
    signum (T2 x y)     = T2 (signum x) (signum y)
    fromInteger x       = T2 (fromInteger x) (fromInteger x)

instance (Fractional a, Fractional b) => Fractional (T2 a b) where
    T2 x1 y1 / T2 x2 y2 = T2 (x1 / x2) (y1 / y2)
    recip (T2 x y)      = T2 (recip x) (recip y)
    fromRational x      = T2 (fromRational x) (fromRational x)

instance (Floating a, Floating b) => Floating (T2 a b) where
    pi                            = T2 pi pi
    T2 x1 y1 ** T2 x2 y2          = T2 (x1 ** x2) (y1 ** y2)
    logBase (T2 x1 y1) (T2 x2 y2) = T2 (logBase x1 x2) (logBase y1 y2)
    exp   (T2 x y)                = T2 (exp   x) (exp   y)
    log   (T2 x y)                = T2 (log   x) (log   y)
    sqrt  (T2 x y)                = T2 (sqrt  x) (sqrt  y)
    sin   (T2 x y)                = T2 (sin   x) (sin   y)
    cos   (T2 x y)                = T2 (cos   x) (cos   y)
    asin  (T2 x y)                = T2 (asin  x) (asin  y)
    acos  (T2 x y)                = T2 (acos  x) (acos  y)
    atan  (T2 x y)                = T2 (atan  x) (atan  y)
    sinh  (T2 x y)                = T2 (sinh  x) (sinh  y)
    cosh  (T2 x y)                = T2 (cosh  x) (cosh  y)
    asinh (T2 x y)                = T2 (asinh x) (asinh y)
    acosh (T2 x y)                = T2 (acosh x) (acosh y)
    atanh (T2 x y)                = T2 (atanh x) (atanh y)

instance (Semigroup a, Semigroup b) => Semigroup (T2 a b) where
    T2 x1 y1 <> T2 x2 y2 = T2 (x1 <> x2) (y1 <> y2)

#if MIN_VERSION_base(4,11,0)
instance (Monoid a, Monoid b) => Monoid (T2 a b) where
#else
instance (Semigroup a, Semigroup b, Monoid a, Monoid b) => Monoid (T2 a b) where
#endif
    mappend = (<>)
    mempty  = T2 mempty mempty

instance (Num a, Num b, Num c) => Num (T3 a b c) where
    T3 x1 y1 z1 + T3 x2 y2 z2 = T3 (x1 + x2) (y1 + y2) (z1 + z2)
    T3 x1 y1 z1 - T3 x2 y2 z2 = T3 (x1 - x2) (y1 - y2) (z1 + z2)
    T3 x1 y1 z1 * T3 x2 y2 z2 = T3 (x1 * x2) (y1 * y2) (z1 + z2)
    negate (T3 x y z)         = T3 (negate x) (negate y) (negate z)
    abs    (T3 x y z)         = T3 (abs    x) (abs    y) (abs    z)
    signum (T3 x y z)         = T3 (signum x) (signum y) (signum z)
    fromInteger x             = T3 (fromInteger x) (fromInteger x) (fromInteger x)

instance (Fractional a, Fractional b, Fractional c) => Fractional (T3 a b c) where
    T3 x1 y1 z1 / T3 x2 y2 z2 = T3 (x1 / x2) (y1 / y2) (z1 / z2)
    recip (T3 x y z)          = T3 (recip x) (recip y) (recip z)
    fromRational x            = T3 (fromRational x) (fromRational x) (fromRational x)

instance (Floating a, Floating b, Floating c) => Floating (T3 a b c) where
    pi                                  = T3 pi pi pi
    T3 x1 y1 z1 ** T3 x2 y2 z2          = T3 (x1 ** x2) (y1 ** y2) (z1 ** z2)
    logBase (T3 x1 y1 z1) (T3 x2 y2 z2) = T3 (logBase x1 x2) (logBase y1 y2) (logBase z1 z2)
    exp   (T3 x y z)                    = T3 (exp   x) (exp   y) (exp   z)
    log   (T3 x y z)                    = T3 (log   x) (log   y) (log   z)
    sqrt  (T3 x y z)                    = T3 (sqrt  x) (sqrt  y) (sqrt  z)
    sin   (T3 x y z)                    = T3 (sin   x) (sin   y) (sin   z)
    cos   (T3 x y z)                    = T3 (cos   x) (cos   y) (cos   z)
    asin  (T3 x y z)                    = T3 (asin  x) (asin  y) (asin  z)
    acos  (T3 x y z)                    = T3 (acos  x) (acos  y) (acos  z)
    atan  (T3 x y z)                    = T3 (atan  x) (atan  y) (atan  z)
    sinh  (T3 x y z)                    = T3 (sinh  x) (sinh  y) (sinh  z)
    cosh  (T3 x y z)                    = T3 (cosh  x) (cosh  y) (cosh  z)
    asinh (T3 x y z)                    = T3 (asinh x) (asinh y) (asinh z)
    acosh (T3 x y z)                    = T3 (acosh x) (acosh y) (acosh z)
    atanh (T3 x y z)                    = T3 (atanh x) (atanh y) (atanh z)

instance (Semigroup a, Semigroup b, Semigroup c) => Semigroup (T3 a b c) where
    T3 x1 y1 z1 <> T3 x2 y2 z2 = T3 (x1 <> x2) (y1 <> y2) (z1 <> z2)

#if MIN_VERSION_base(4,11,0)
instance (Monoid a, Monoid b, Monoid c) => Monoid (T3 a b c) where
#else
instance (Semigroup a, Semigroup b, Semigroup c, Monoid a, Monoid b, Monoid c) => Monoid (T3 a b c) where
#endif
    mappend = (<>)
    mempty  = T3 mempty mempty mempty

-- | Initialize a 'T' with a Rank-N value.  Mostly used internally, but
-- provided in case useful.
--
-- Must be used with /TypeApplications/ to provide the Rank-N constraint.
--
-- @since 0.1.5.0
constT
    :: forall c as. ListC (c <$> as)
    => (forall a. c a => a)
    -> Length as
    -> T as
constT x = go
  where
    go :: forall bs. ListC (c <$> bs) => Length bs -> T bs
    go LZ     = TNil
    go (LS l) = x :& go l

-- | Map over a 'T' with a Rank-N function.  Mostly used internally, but
-- provided in case useful.
--
-- Must be used with /TypeApplications/ to provide the Rank-N constraint.
--
-- @since 0.1.5.0
mapT
    :: forall c as. ListC (c <$> as)
    => (forall a. c a => a -> a)
    -> T as
    -> T as
mapT f = go
  where
    go :: forall bs. ListC (c <$> bs) => T bs -> T bs
    go TNil      = TNil
    go (x :& xs) = f x :& go xs

-- | Map over a 'T' with a Rank-N function.  Mostly used internally, but
-- provided in case useful.
--
-- Must be used with /TypeApplications/ to provide the Rank-N constraint.
--
-- @since 0.1.5.0
zipT
    :: forall c as. ListC (c <$> as)
    => (forall a. c a => a -> a -> a)
    -> T as
    -> T as
    -> T as
zipT f = go
  where
    go :: forall bs. ListC (c <$> bs) => T bs -> T bs -> T bs
    go TNil      TNil      = TNil
    go (x :& xs) (y :& ys) = f x y :& go xs ys

instance (Known Length as, ListC (Num <$> as)) => Num (T as) where
    (+)           = zipT @Num (+)
    (-)           = zipT @Num (-)
    (*)           = zipT @Num (*)
    negate        = mapT @Num negate
    abs           = mapT @Num abs
    signum        = mapT @Num signum
    fromInteger x = constT @Num (fromInteger x) known

instance (Known Length as, ListC (Num <$> as), ListC (Fractional <$> as)) => Fractional (T as) where
    (/)            = zipT @Fractional (/)
    recip          = mapT @Fractional recip
    fromRational x = constT @Fractional (fromRational x) known

instance (Known Length as, ListC (Num <$> as), ListC (Fractional <$> as), ListC (Floating <$> as))
        => Floating (T as) where
    pi      = constT @Floating pi known
    (**)    = zipT @Floating (**)
    logBase = zipT @Floating logBase
    exp     = mapT @Floating exp
    log     = mapT @Floating log
    sqrt    = mapT @Floating sqrt
    sin     = mapT @Floating sin
    cos     = mapT @Floating cos
    asin    = mapT @Floating asin
    acos    = mapT @Floating acos
    atan    = mapT @Floating atan
    sinh    = mapT @Floating sinh
    cosh    = mapT @Floating cosh
    asinh   = mapT @Floating asinh
    acosh   = mapT @Floating acosh
    atanh   = mapT @Floating atanh

instance ListC (Semigroup <$> as) => Semigroup (T as) where
    (<>) = zipT @Semigroup (<>)

instance (Known Length as, ListC (Semigroup <$> as), ListC (Monoid <$> as)) => Monoid (T as) where
    mempty  = constT @Monoid mempty known
    mappend = (<>)

-- $t2iso
--
-- If using /lens/, the two conversion functions can be chained with prisms
-- and traversals and other optics using:
--
-- @
-- 'iso' 'tupT2' 't2Tup' :: 'Iso'' (a, b) ('T2' a b)
-- @

-- $t3iso
--
-- If using /lens/, the two conversion functions can be chained with prisms
-- and traversals and other optics using:
--
-- @
-- 'iso' 'tupT3' 't2Tup' :: 'Iso'' (a, b, c) ('T3' a b c)
-- @

-- $tiso
--
-- If using /lens/, the two conversion functions can be chained with prisms
-- and traversals and other optics using:
--
-- @
-- 'iso' 'onlyT' 'tOnly' :: 'Iso'' a (T '[a])
-- @