module Numeric.Backprop.Tuple (
T2(..)
, t2Tup, tupT2
, uncurryT2, curryT2
, t2_1, t2_2
, T3(..)
, t3Tup, tupT3
, t3_1, t3_2, t3_3
, uncurryT3, curryT3
) where
import Control.DeepSeq
import Data.Bifunctor
import Data.Data
import Data.Semigroup
import GHC.Generics (Generic)
import Lens.Micro
import Lens.Micro.Internal
-- | A two-element tuple whose fields are both strict: forcing a 'T2' to
-- weak head normal form also forces each component to weak head normal
-- form.  This module equips it with componentwise numeric and monoidal
-- instances, which the built-in pair type does not have.
data T2 a b = T2 !a !b
    deriving ( Eq
             , Ord
             , Show
             , Read
             , Data
             , Functor
             , Generic
             )
-- | A three-element tuple whose fields are all strict: forcing a 'T3' to
-- weak head normal form also forces each component to weak head normal
-- form.  This module equips it with componentwise numeric and monoidal
-- instances, which the built-in triple type does not have.
data T3 a b c = T3 !a !b !c
    deriving ( Eq
             , Ord
             , Show
             , Read
             , Data
             , Functor
             , Generic
             )
-- | Reduces both components to normal form, first component first.
instance (NFData a, NFData b) => NFData (T2 a b) where
    rnf (T2 x y) = rnf x `seq` rnf y
-- | Reduces all three components to normal form, left to right.
instance (NFData a, NFData b, NFData c) => NFData (T3 a b c) where
    rnf (T3 x y z) = rnf x `seq` rnf y `seq` rnf z
-- | 'bimap' applies its first function to the first field and its second
-- function to the second field.
instance Bifunctor T2 where
    bimap f g t = case t of
        T2 a b -> T2 (f a) (g b)
-- | The first field is left untouched; 'bimap' maps the second field with
-- its first function and the third field with its second function.
instance Bifunctor (T3 a) where
    bimap f g t = case t of
        T3 a b c -> T3 a (f b) (g c)
-- | Convert a 'T2' into an ordinary Haskell pair.
t2Tup :: T2 a b -> (a, b)
t2Tup t = case t of
    T2 a b -> (a, b)
-- | Convert an ordinary Haskell pair into a 'T2'.  Because 'T2' has strict
-- fields, forcing the result forces both components.
tupT2 :: (a, b) -> T2 a b
tupT2 = uncurry T2
-- | Convert a 'T3' into an ordinary Haskell triple.
t3Tup :: T3 a b c -> (a, b, c)
t3Tup t = case t of
    T3 a b c -> (a, b, c)
-- | Convert an ordinary Haskell triple into a 'T3'.  Because 'T3' has
-- strict fields, forcing the result forces all three components.
tupT3 :: (a, b, c) -> T3 a b c
tupT3 t = case t of
    (a, b, c) -> T3 a b c
-- | Feed the two fields of a 'T2' to a curried two-argument function.
uncurryT2 :: (a -> b -> c) -> T2 a b -> c
uncurryT2 f t = case t of
    T2 a b -> f a b
-- | Turn a function on 'T2' into a curried two-argument function.
curryT2 :: (T2 a b -> c) -> a -> b -> c
curryT2 f a = f . T2 a
-- | Feed the three fields of a 'T3' to a curried three-argument function.
uncurryT3 :: (a -> b -> c -> d) -> T3 a b c -> d
uncurryT3 f t = case t of
    T3 a b c -> f a b c
-- | Turn a function on 'T3' into a curried three-argument function.
curryT3 :: (T3 a b c -> d) -> a -> b -> c -> d
curryT3 f a b = f . T3 a b
-- | Type-changing lens onto the first field of a 'T2'.
instance Field1 (T2 a b) (T2 a' b) a a' where
    _1 f t = case t of
        T2 a b -> fmap (\a' -> T2 a' b) (f a)
-- | Type-changing lens onto the second field of a 'T2'.
instance Field2 (T2 a b) (T2 a b') b b' where
    _2 f t = case t of
        T2 a b -> fmap (\b' -> T2 a b') (f b)
-- | Type-changing lens onto the first field of a 'T3'.
instance Field1 (T3 a b c) (T3 a' b c) a a' where
    _1 f t = case t of
        T3 a b c -> fmap (\a' -> T3 a' b c) (f a)
-- | Type-changing lens onto the second field of a 'T3'.
instance Field2 (T3 a b c) (T3 a b' c) b b' where
    _2 f t = case t of
        T3 a b c -> fmap (\b' -> T3 a b' c) (f b)
-- | Type-changing lens onto the third field of a 'T3'.
instance Field3 (T3 a b c) (T3 a b c') c c' where
    _3 f t = case t of
        T3 a b c -> fmap (\c' -> T3 a b c') (f c)
-- | Lens onto the first field of a 'T2'; behaves the same as '_1' on 'T2'.
t2_1 :: Lens (T2 a b) (T2 a' b) a a'
t2_1 f (T2 a b) = fmap (\a' -> T2 a' b) (f a)
-- | Lens onto the second field of a 'T2'; behaves the same as '_2' on 'T2'.
t2_2 :: Lens (T2 a b) (T2 a b') b b'
t2_2 f (T2 a b) = fmap (T2 a) (f b)
-- | Lens onto the first field of a 'T3'; behaves the same as '_1' on 'T3'.
t3_1 :: Lens (T3 a b c) (T3 a' b c) a a'
t3_1 f (T3 a b c) = fmap (\a' -> T3 a' b c) (f a)
-- | Lens onto the second field of a 'T3'; behaves the same as '_2' on 'T3'.
t3_2 :: Lens (T3 a b c) (T3 a b' c) b b'
t3_2 f (T3 a b c) = fmap (\b' -> T3 a b' c) (f b)
-- | Lens onto the third field of a 'T3'; behaves the same as '_3' on 'T3'.
t3_3 :: Lens (T3 a b c) (T3 a b c') c c'
t3_3 f (T3 a b c) = fmap (T3 a b) (f c)
-- | Componentwise arithmetic: each operation is applied independently to
-- the two fields, and 'fromInteger' copies the literal into both fields.
instance (Num a, Num b) => Num (T2 a b) where
    T2 x1 y1 + T2 x2 y2 = T2 (x1 + x2) (y1 + y2)
    -- Subtraction must be componentwise like the other operators; the
    -- previous equation had lost its operator and did not define (-).
    T2 x1 y1 - T2 x2 y2 = T2 (x1 - x2) (y1 - y2)
    T2 x1 y1 * T2 x2 y2 = T2 (x1 * x2) (y1 * y2)
    negate (T2 x y) = T2 (negate x) (negate y)
    abs    (T2 x y) = T2 (abs x)    (abs y)
    signum (T2 x y) = T2 (signum x) (signum y)
    fromInteger n   = T2 (fromInteger n) (fromInteger n)
-- | Componentwise division and reciprocal; 'fromRational' copies the
-- literal into both fields.
instance (Fractional a, Fractional b) => Fractional (T2 a b) where
    fromRational r = T2 (fromRational r) (fromRational r)
    recip t = case t of
        T2 a b -> T2 (recip a) (recip b)
    (/) (T2 a1 b1) (T2 a2 b2) = T2 (a1 / a2) (b1 / b2)
-- | Componentwise 'Floating' operations: every method is applied
-- independently to the two fields, and 'pi' fills both fields with 'pi'.
instance (Floating a, Floating b) => Floating (T2 a b) where
    pi = T2 pi pi
    T2 x1 y1 ** T2 x2 y2 = T2 (x1 ** x2) (y1 ** y2)
    logBase (T2 x1 y1) (T2 x2 y2) = T2 (logBase x1 x2) (logBase y1 y2)
    exp   (T2 x y) = T2 (exp x)   (exp y)
    log   (T2 x y) = T2 (log x)   (log y)
    sqrt  (T2 x y) = T2 (sqrt x)  (sqrt y)
    sin   (T2 x y) = T2 (sin x)   (sin y)
    cos   (T2 x y) = T2 (cos x)   (cos y)
    -- Defined explicitly so each field uses its own type's 'tan'/'tanh'
    -- instead of the class defaults (sin/cos and sinh/cosh quotients).
    tan   (T2 x y) = T2 (tan x)   (tan y)
    asin  (T2 x y) = T2 (asin x)  (asin y)
    acos  (T2 x y) = T2 (acos x)  (acos y)
    atan  (T2 x y) = T2 (atan x)  (atan y)
    sinh  (T2 x y) = T2 (sinh x)  (sinh y)
    cosh  (T2 x y) = T2 (cosh x)  (cosh y)
    tanh  (T2 x y) = T2 (tanh x)  (tanh y)
    asinh (T2 x y) = T2 (asinh x) (asinh y)
    acosh (T2 x y) = T2 (acosh x) (acosh y)
    atanh (T2 x y) = T2 (atanh x) (atanh y)
-- | Combines the fields pairwise: first with first, second with second.
instance (Semigroup a, Semigroup b) => Semigroup (T2 a b) where
    (<>) (T2 a1 b1) (T2 a2 b2) = T2 (a1 <> a2) (b1 <> b2)
-- | The identity holds 'mempty' in both fields; 'mappend' combines the
-- fields pairwise.
instance (Monoid a, Monoid b) => Monoid (T2 a b) where
    mempty = T2 mempty mempty
    mappend (T2 a1 b1) (T2 a2 b2) =
        T2 (a1 `mappend` a2) (b1 `mappend` b2)
-- | Componentwise arithmetic: each operation is applied independently to
-- all three fields, and 'fromInteger' copies the literal into every field.
instance (Num a, Num b, Num c) => Num (T3 a b c) where
    T3 x1 y1 z1 + T3 x2 y2 z2 = T3 (x1 + x2) (y1 + y2) (z1 + z2)
    -- The third field previously used (+) for both (-) and (*); every
    -- field must use the operator being defined.
    T3 x1 y1 z1 - T3 x2 y2 z2 = T3 (x1 - x2) (y1 - y2) (z1 - z2)
    T3 x1 y1 z1 * T3 x2 y2 z2 = T3 (x1 * x2) (y1 * y2) (z1 * z2)
    negate (T3 x y z) = T3 (negate x) (negate y) (negate z)
    abs    (T3 x y z) = T3 (abs x)    (abs y)    (abs z)
    signum (T3 x y z) = T3 (signum x) (signum y) (signum z)
    fromInteger n     = T3 (fromInteger n) (fromInteger n) (fromInteger n)
-- | Componentwise division and reciprocal; 'fromRational' copies the
-- literal into all three fields.
instance (Fractional a, Fractional b, Fractional c) => Fractional (T3 a b c) where
    fromRational r = T3 (fromRational r) (fromRational r) (fromRational r)
    recip t = case t of
        T3 a b c -> T3 (recip a) (recip b) (recip c)
    (/) (T3 a1 b1 c1) (T3 a2 b2 c2) = T3 (a1 / a2) (b1 / b2) (c1 / c2)
-- | Componentwise 'Floating' operations: every method is applied
-- independently to the three fields, and 'pi' fills every field with 'pi'.
instance (Floating a, Floating b, Floating c) => Floating (T3 a b c) where
    pi = T3 pi pi pi
    T3 x1 y1 z1 ** T3 x2 y2 z2 = T3 (x1 ** x2) (y1 ** y2) (z1 ** z2)
    logBase (T3 x1 y1 z1) (T3 x2 y2 z2) =
        T3 (logBase x1 x2) (logBase y1 y2) (logBase z1 z2)
    exp   (T3 x y z) = T3 (exp x)   (exp y)   (exp z)
    log   (T3 x y z) = T3 (log x)   (log y)   (log z)
    sqrt  (T3 x y z) = T3 (sqrt x)  (sqrt y)  (sqrt z)
    sin   (T3 x y z) = T3 (sin x)   (sin y)   (sin z)
    cos   (T3 x y z) = T3 (cos x)   (cos y)   (cos z)
    -- Defined explicitly so each field uses its own type's 'tan'/'tanh'
    -- instead of the class defaults (sin/cos and sinh/cosh quotients).
    tan   (T3 x y z) = T3 (tan x)   (tan y)   (tan z)
    asin  (T3 x y z) = T3 (asin x)  (asin y)  (asin z)
    acos  (T3 x y z) = T3 (acos x)  (acos y)  (acos z)
    atan  (T3 x y z) = T3 (atan x)  (atan y)  (atan z)
    sinh  (T3 x y z) = T3 (sinh x)  (sinh y)  (sinh z)
    cosh  (T3 x y z) = T3 (cosh x)  (cosh y)  (cosh z)
    tanh  (T3 x y z) = T3 (tanh x)  (tanh y)  (tanh z)
    asinh (T3 x y z) = T3 (asinh x) (asinh y) (asinh z)
    acosh (T3 x y z) = T3 (acosh x) (acosh y) (acosh z)
    atanh (T3 x y z) = T3 (atanh x) (atanh y) (atanh z)
-- | Combines the fields pairwise, position by position.
instance (Semigroup a, Semigroup b, Semigroup c) => Semigroup (T3 a b c) where
    (<>) (T3 a1 b1 c1) (T3 a2 b2 c2) = T3 (a1 <> a2) (b1 <> b2) (c1 <> c2)
-- | The identity holds 'mempty' in every field; 'mappend' combines the
-- fields pairwise, position by position.
instance (Monoid a, Monoid b, Monoid c) => Monoid (T3 a b c) where
    mempty = T3 mempty mempty mempty
    mappend (T3 a1 b1 c1) (T3 a2 b2 c2) =
        T3 (a1 `mappend` a2) (b1 `mappend` b2) (c1 `mappend` c2)