{-# LANGUAGE CPP #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE NoImplicitPrelude #-}
{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE UndecidableInstances #-}
{-# OPTIONS_GHC -Wall #-}
#if ( __GLASGOW_HASKELL__ < 820 )
{-# OPTIONS_GHC -fno-warn-incomplete-patterns #-}
{-# OPTIONS_GHC -fno-warn-unrecognised-pragmas #-}
#endif
module NumHask.Pair
  ( Pair(..)
  , pattern Pair
  ) where
import Data.Distributive
import Data.Functor.Apply (Apply(..))
import Data.Functor.Classes
import Data.Functor.Rep
import Data.Semigroup.Foldable (Foldable1(..))
import Data.Semigroup.Traversable (Traversable1(..))
import NumHask.Prelude
import Test.QuickCheck.Arbitrary (Arbitrary(..))
import Text.Show
-- | Two values of the same type, stored as a tuple behind the 'Pair''
-- constructor. Prefer the 'Pair' pattern synonym for building and matching.
newtype Pair a =
  Pair' (a, a)
  deriving (Eq, Generic)
-- | Build or match a 'Pair' from its first and second components.
pattern Pair :: a -> a -> Pair a
pattern Pair x y <- Pair' (x, y)
  where
    Pair x y = Pair' (x, y)
-- NOTE(review): the CPP guard at the top of this module tests
-- @__GLASGOW_HASKELL__ < 820@, but GHC 8.2 (the first with COMPLETE pragmas)
-- reports 802, so that guard is true for every 8.x compiler; @< 802@ was
-- probably intended.
{-# COMPLETE Pair #-}
-- | Render a pair as @Pair x y@.
--
-- Precedence is respected, matching the 'Show1' instance. The whole
-- expression is parenthesised when it appears as an argument of another
-- constructor. Each component is shown at application precedence. For
-- example, with 'Int' components, @Just (Pair (-1) 2)@ now renders as
-- @"Just (Pair (-1) 2)"@ rather than the ambiguous @"Just Pair -1 2"@.
-- Top-level output such as @"Pair 1 2"@ is unchanged.
instance (Show a) => Show (Pair a) where
  showsPrec d (Pair a b) = showsBinaryWith showsPrec showsPrec "Pair" d a b
-- | Apply the function to both components.
instance Functor Pair where
  fmap g (Pair' (x, y)) = Pair' (g x, g y)
-- | Two pairs are equal when the supplied equality holds for the first
-- components and then for the second components.
instance Eq1 Pair where
  liftEq eq (Pair' (x0, y0)) (Pair' (x1, y1)) = eq x0 x1 && eq y0 y1
-- | Show as a binary constructor application named @Pair@, using the
-- supplied element printer for both components.
instance Show1 Pair where
  liftShowsPrec showsElem _ prec (Pair x y) =
    showsBinaryWith showsElem showsElem "Pair" prec x y
-- | Component-wise application, identical to the 'Applicative' instance.
instance Apply Pair where
  (<.>) = (<*>)
-- | 'pure' duplicates its argument. '<*>' applies the first function to the
-- first value and the second function to the second value.
instance Applicative Pair where
  pure x = Pair x x
  Pair f g <*> Pair x y = Pair (f x) (g y)
-- | Bind keeps the first component of @k x@ and the second component of
-- @k y@, i.e. the diagonal of the two resulting pairs.
instance Monad Pair where
  Pair x y >>= k = Pair (pickL (k x)) (pickR (k y))
    where
      pickL (Pair l _) = l
      pickR (Pair _ r) = r
-- | Fold the first component, then the second.
instance Foldable Pair where
  foldMap f (Pair' (x, y)) = mappend (f x) (f y)
-- | A 'Pair' always holds exactly two elements, so it can be folded with a
-- 'Semigroup' alone.
--
-- 'foldMap1' is defined directly (total, no intermediate wrapper) instead of
-- relying on the semigroupoids class default. That default derives it from
-- 'foldMap' through an optional accumulator with an 'error' branch for the
-- empty case.
instance Foldable1 Pair where
  foldMap1 f (Pair a b) = f a <> f b
-- | Run the effect for the first component, then the second, and rebuild
-- the pair.
instance Traversable Pair where
  traverse f (Pair x y) = fmap Pair (f x) <*> f y
-- | As 'Traversable', but needs only an 'Apply' (no 'pure').
instance Traversable1 Pair where
  traverse1 f (Pair x y) = fmap Pair (f x) Data.Functor.Apply.<.> f y
-- | Combine component-wise.
instance (Semigroup a) => Semigroup (Pair a) where
  Pair x0 y0 <> Pair x1 y1 = Pair (x0 <> x1) (y0 <> y1)
-- | The identity is 'mempty' in both components.
instance (Semigroup a, Monoid a) => Monoid (Pair a) where
  mempty = pure mempty
  mappend = (<>)
-- | Split a functor of pairs into a pair of functors: the first result holds
-- every first component, the second every second component.
instance Distributive Pair where
  collect f xs = Pair (fmap (pickL . f) xs) (fmap (pickR . f) xs)
    where
      pickL p = case p of Pair l _ -> l
      pickR p = case p of Pair _ r -> r
-- | A pair is indexed by 'Bool': 'False' selects the first component and
-- 'True' the second.
instance Representable Pair where
  type Rep Pair = Bool
  tabulate f = Pair (f False) (f True)
  index (Pair l r) b = if b then r else l
-- | Fully evaluate the first component, then the second.
instance NFData a => NFData (Pair a) where
  rnf (Pair' (x, y)) = seq (rnf x) (rnf y)
-- | Generate both components independently.
--
-- 'shrink' tries shrinking the first component and then the second, one at
-- a time, so QuickCheck can minimise counterexamples that involve a 'Pair'.
-- Without it, the class default does no shrinking.
instance (Arbitrary a) => Arbitrary (Pair a) where
  arbitrary = do
    a <- arbitrary
    b <- arbitrary
    pure (Pair a b)
  shrink (Pair a b) =
    [Pair a' b | a' <- shrink a] <> [Pair a b' | b' <- shrink b]
-- | Component-wise addition.
instance (AdditiveMagma a) => AdditiveMagma (Pair a) where
  plus (Pair x0 y0) (Pair x1 y1) = Pair (plus x0 x1) (plus y0 y1)
-- | The additive identity is 'zero' in both components.
instance (AdditiveUnital a) => AdditiveUnital (Pair a) where
  zero = pure zero
-- NOTE(review): associativity and commutativity are claimed for 'Pair a'
-- given only 'AdditiveMagma a'. They hold only when the component addition
-- is itself associative/commutative; consider tightening these constraints.
instance (AdditiveMagma a) => AdditiveAssociative (Pair a)
instance (AdditiveMagma a) => AdditiveCommutative (Pair a)
instance (AdditiveUnital a) => Additive (Pair a)
-- | Component-wise negation.
instance (AdditiveInvertible a) => AdditiveInvertible (Pair a) where
  negate = fmap negate
instance (AdditiveUnital a, AdditiveInvertible a) =>
         AdditiveGroup (Pair a)
-- | Component-wise multiplication.
instance (MultiplicativeMagma a) => MultiplicativeMagma (Pair a) where
  times (Pair x0 y0) (Pair x1 y1) = Pair (times x0 x1) (times y0 y1)
-- | The multiplicative identity is 'one' in both components.
instance (MultiplicativeUnital a) => MultiplicativeUnital (Pair a) where
  one = pure one
-- NOTE(review): as with the additive instances, associativity and
-- commutativity are claimed given only 'MultiplicativeMagma a'; they hold
-- only when the component multiplication has those properties.
instance (MultiplicativeMagma a) => MultiplicativeAssociative (Pair a)
instance (MultiplicativeMagma a) => MultiplicativeCommutative (Pair a)
instance (MultiplicativeUnital a) => Multiplicative (Pair a)
-- | Component-wise reciprocal.
instance (MultiplicativeInvertible a) => MultiplicativeInvertible (Pair a) where
  recip = fmap recip
instance (MultiplicativeUnital a, MultiplicativeInvertible a) =>
         MultiplicativeGroup (Pair a)
-- | Component-wise integer division. Each numerator component is divided by
-- the matching divisor component; quotients and remainders are returned as
-- separate pairs.
instance (Integral a) => Integral (Pair a) where
  divMod (Pair n0 n1) (Pair d0 d1) =
    let (q0, r0) = divMod n0 d0
        (q1, r1) = divMod n1 d1
     in (Pair q0 q1, Pair r0 r1)
  quotRem (Pair n0 n1) (Pair d0 d1) =
    let (q0, r0) = quotRem n0 d0
        (q1, r1) = quotRem n1 d1
     in (Pair q0 q1, Pair r0 r1)
-- | 'sign' and 'abs' applied to each component independently.
instance (Signed a) => Signed (Pair a) where
  sign = fmap sign
  abs = fmap abs
-- | Norms of a pair, built from the norms of its components:
--
-- * 'normL1' is the sum of the component L1 norms.
-- * 'normL2' is the square root of the sum of squared component L2 norms.
-- * 'normLp' is the p-th root of the sum of component L1 norms raised to p.
instance (ExpField a, Normed a a, MultiplicativeUnital a) =>
         Normed (Pair a) a where
  normL1 (Pair a b) = normL1 a + normL1 b
  -- Square the component norms by multiplication. Raising the raw
  -- components with '(**)' (as before) may go through 'log', which is
  -- undefined for negative components, and is also inconsistent with
  -- 'normL1'/'normLp', which already work on component norms.
  normL2 (Pair a b) = sqrt (na * na + nb * nb)
    where
      na = normL2 a
      nb = normL2 b
  normLp p (Pair a b) = (normL1 a ** p + normL1 b ** p) ** (one/p)
-- | Order pairs primarily by L1 magnitude (the sum of the absolute values
-- of their components). Ties are broken lexicographically on the
-- components.
--
-- The tie-break keeps the order consistent with the derived structural
-- 'Eq'. Comparing magnitude alone made @Pair 1 2@ and @Pair 2 1@ each '<='
-- the other while being unequal, which violates antisymmetry.
instance (Eq a, Ord a, Signed a, Additive a) => Ord (Pair a) where
  compare (Pair x y) (Pair x' y') =
    compare (abs x + abs y) (abs x' + abs y') <> compare x x' <> compare y y'
-- | A pair is near zero only when both components are. Two pairs are about
-- equal when their difference is near zero.
instance (Epsilon a) => Epsilon (Pair a) where
  nearZero (Pair' (x, y)) = nearZero x && nearZero y
  aboutEqual p q = nearZero (p - q)
-- | Each distance is the corresponding norm of the difference @x - y@.
instance (ExpField a, Normed a a) => Metric (Pair a) a where
  distanceL1 x = normL1 . (x -)
  distanceL2 x = normL2 . (x -)
  distanceLp p x = normLp p . (x -)
-- Structural instances: each holds for 'Pair a' whenever it holds for 'a'.
-- None of them declares methods of its own.
instance (Distribution a) => Distribution (Pair a)
instance (Semiring a) => Semiring (Pair a)
instance (Ring a) => Ring (Pair a)
instance (CRing a) => CRing (Pair a)
instance (Semifield a) => Semifield (Pair a)
instance (Field a) => Field (Pair a)
-- | Component-wise 'exp' and 'log'.
--
-- NOTE(review): only 'exp' and 'log' are given here. Any other 'ExpField'
-- methods use the class defaults rather than delegating to the component
-- type's own implementations — confirm that is intended.
instance (ExpField a) => ExpField (Pair a) where
  exp = fmap exp
  log = fmap log
instance (UpperBoundedField a) => UpperBoundedField (Pair a)
instance (LowerBoundedField a) => LowerBoundedField (Pair a)
-- | Element-wise addition of two pairs.
instance (Additive a) => AdditiveBasis Pair a where
    a .+. b = liftR2 (+) a b
-- | Element-wise subtraction of two pairs.
instance (AdditiveGroup a) => AdditiveGroupBasis Pair a where
    a .-. b = liftR2 (-) a b
-- | Element-wise multiplication of two pairs.
instance (Multiplicative a) => MultiplicativeBasis Pair a where
    a .*. b = liftR2 (*) a b
-- | Element-wise division of two pairs.
instance (MultiplicativeGroup a) => MultiplicativeGroupBasis Pair a where
    a ./. b = liftR2 (/) a b
-- | Add a scalar to every component. The scalar may appear on either side;
-- both forms compute @s + x@ for each component @x@.
instance (Additive a) => AdditiveModule Pair a where
    r .+ s = (s +) <$> r
    s +. r = (s +) <$> r
-- | Subtraction between a pair and a scalar.
--
-- * @r .- s@ subtracts the scalar from each component: @x - s@.
-- * @s -. r@ has the scalar on the left, so it is the minuend: @s - x@.
--   (This previously computed @x - s@, making it identical to '.-'.)
instance (AdditiveGroup a) => AdditiveGroupModule Pair a where
    (.-) r s = fmap (\x -> x - s) r
    (-.) s r = fmap (s -) r
-- | Scale every component by a scalar. The scalar may appear on either side;
-- both forms compute @s * x@ for each component @x@.
instance (Multiplicative a) => MultiplicativeModule Pair a where
    r .* s = (s *) <$> r
    s *. r = (s *) <$> r
-- | Division between a pair and a scalar.
--
-- * @r ./ s@ divides each component by the scalar: @x / s@.
-- * @s /. r@ has the scalar on the left, so it is the dividend: @s / x@.
--   (This previously computed @x / s@, making it identical to './'.)
instance (MultiplicativeGroup a) => MultiplicativeGroupModule Pair a where
    (./) r s = fmap (/ s) r
    (/.) s r = fmap (s /) r
-- | A singleton pair repeats its argument in both components.
instance Singleton Pair where
    singleton = pure