{-# LANGUAGE CPP #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE NoImplicitPrelude #-}
{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE TypeFamilies #-}
{-# OPTIONS_GHC -Wall #-}
#if ( __GLASGOW_HASKELL__ < 820 )
{-# OPTIONS_GHC -fno-warn-incomplete-patterns #-}
{-# OPTIONS_GHC -fno-warn-unrecognised-pragmas #-}
#endif

{-

Design note: I would have used V2 from the linear package, but wanted to avoid
the lens dependency that comes with it.

-}

module NumHask.Pair
  ( Pair(..)
  , pattern Pair
  ) where

import NumHask.Prelude

import Data.Functor.Classes
import qualified Data.Semigroup as Semigroup
import qualified Text.Show

import Data.Distributive
import Data.Functor.Apply (Apply(..))
import Data.Functor.Rep
import Data.Semigroup.Foldable (Foldable1(..))
import Data.Semigroup.Traversable (Traversable1(..))
import Test.QuickCheck.Arbitrary (Arbitrary(..))

-- $setup
-- >>> :set -XNoImplicitPrelude

-- | A Pair
--
-- >>> fmap (+1) (Pair 1 2)
-- Pair 2 3
--
-- >>> pure one :: Pair Int
-- Pair 1 1
--
-- >>> (*) <$> Pair 1 2 <*> pure 2
-- Pair 2 4
--
-- >>> foldr (++) [] (Pair [1,2] [3])
-- [1,2,3]
--
-- >>> Pair "a" "pair" <> pure " " <> Pair "string" "mappend"
-- Pair "a string" "pair mappend"
--
-- Numeric instances:
--
-- >>> Pair 0 1 + zero
-- Pair 0 1
--
-- >>> Pair 0 1 + Pair 2 3
-- Pair 2 4
--
-- >>> Pair 1 1 - one
-- Pair 0 0
--
-- >>> Pair 0 1 * one
-- Pair 0 1
--
-- >>> Pair 0 1 / one
-- Pair 0.0 1.0
--
-- >>> Pair 11 12 `mod` (pure 6)
-- Pair 5 0
--
-- Module (scalar action) instances:
--
-- >>> Pair 1 2 .+ 3
-- Pair 4 5
--
-- Distributive and Representable instances:
--
-- >>> distribute [Pair 1 2, Pair 3 4]
-- Pair [1,3] [2,4]
--
-- >>> index (Pair 'l' 'r') False
-- 'l'
-- 


-- | A pair of @a@s. It is stored as a tuple, but the API presents it through
-- the 'Pair' pattern synonym.
--
-- 'Show' is written by hand rather than derived. A derived instance would
-- expose the hidden constructor as @Pair' (x,y)@. This one renders through
-- the 'Show1' instance, so values print as @Pair x y@, as in the examples
-- above.
newtype Pair a = Pair' (a,a)
    deriving (Eq, Ord, Generic)

instance (Show a) => Show (Pair a) where
    showsPrec = showsPrec1

-- | Builds or takes apart a 'Pair' from its left and right components.
pattern Pair :: a -> a -> Pair a
pattern Pair x y <- Pair' (x, y)
  where
    Pair x y = Pair' (x, y)
{-# COMPLETE Pair #-}

-- | Applies the function to each component independently.
instance Functor Pair where
    fmap g (Pair' (x, y)) = Pair' (g x, g y)

-- | Lifted equality: the left components and the right components must both
-- satisfy the supplied predicate (left checked first via '&&').
instance Eq1 Pair where
    liftEq eq (Pair' (x0, y0)) (Pair' (x1, y1)) = eq x0 x1 && eq y0 y1

-- | Lifted 'Show': renders via 'showsBinaryWith' under the constructor name
-- "Pair", using the supplied element printer for both components.
instance Show1 Pair where
    liftShowsPrec showElem _showElems prec (Pair x y) =
        showsBinaryWith showElem showElem "Pair" prec x y

-- | Position-wise application: the left function meets the left value and
-- the right function meets the right value.
instance Apply Pair where
  Pair' (f, g) <.> Pair' (x, y) = Pair' (f x, g y)

-- | 'pure' copies a value into both slots; '<*>' applies position-wise.
instance Applicative Pair where
    pure x = Pair' (x, x)
    Pair' (f, g) <*> Pair' (x, y) = Pair' (f x, g y)

-- | Diagonal bind: the result's left component is the left component of
-- @k a@, and its right component is the right component of @k b@.
-- Both projections are lazy thunks, so @k a@ / @k b@ are only evaluated when
-- the corresponding result component is demanded.
instance Monad Pair where
  Pair' (a, b) >>= k = Pair' (leftOf (k a), rightOf (k b))
    where
      leftOf (Pair' (x, _)) = x
      rightOf (Pair' (_, y)) = y

-- | Combines the mapped left component with the mapped right component, in
-- that order.
instance Foldable Pair where
    foldMap f (Pair' (x, y)) = mappend (f x) (f y)

-- | Non-empty fold. A 'Pair' always holds exactly two elements, so the mapped
-- components are combined with the semigroup operation and no identity
-- element is needed.
--
-- 'foldMap1' is defined explicitly; it was previously only a commented-out
-- sketch. The instance therefore no longer depends on the class's generic
-- default. @<>@ is used qualified so that it is unambiguously the 'Semigroup'
-- operation, whatever the prelude exports under that name.
instance Foldable1 Pair where
    foldMap1 f (Pair a b) = f a Semigroup.<> f b

-- | Traverses the left component, then the right one, combining the effects
-- with '<*>'.
instance Traversable Pair where
    traverse f (Pair' (x, y)) = fmap Pair (f x) <*> f y

-- | Like 'traverse', but only requires an 'Apply' (no 'pure').
-- NOTE(review): '<.>' is kept fully qualified, presumably because another
-- '<.>' is in scope unqualified from the prelude; confirm before simplifying.
instance Traversable1 Pair where
    traverse1 f (Pair' (x, y)) = fmap Pair (f x) Data.Functor.Apply.<.> f y

-- | Component-wise 'Semigroup': left components are combined with each other,
-- and so are right components.
--
-- From base-4.11 on, 'Semigroup' is a superclass of 'Monoid', so the 'Monoid'
-- instance below cannot exist without this one. The class and operator are
-- used qualified to stay independent of what the prelude exports.
instance (Semigroup.Semigroup a) => Semigroup.Semigroup (Pair a) where
    Pair a0 b0 <> Pair a1 b1 = Pair (a0 Semigroup.<> a1) (b0 Semigroup.<> b1)

-- | Component-wise 'Monoid': 'mempty' in both slots, 'mappend' per component.
-- 'mappend' is spelled out (rather than defined as @<>@) so the instance also
-- compiles where 'Monoid' does not yet imply 'Semigroup'.
instance (Monoid a) => Monoid (Pair a) where
    mempty  = Pair mempty mempty
    (Pair a0 b0) `mappend` (Pair a1 b1) = Pair (a0 `mappend` a1) (b0 `mappend` b1)

-- | Distributes an outer functor over 'Pair' by building one container of
-- left components and one container of right components.
instance Distributive Pair where
  collect f xs = Pair (fmap (leftOf . f) xs) (fmap (rightOf . f) xs)
    where
      leftOf (Pair' (l, _)) = l
      rightOf (Pair' (_, r)) = r

-- | Positions are indexed by 'Bool': 'False' selects the left component and
-- 'True' the right one.
instance Representable Pair where
  type Rep Pair = Bool
  tabulate f = Pair (f False) (f True)
  index (Pair' (l, r)) i = if i then r else l

-- | Fully evaluates both components.
instance NFData a => NFData (Pair a) where
  rnf (Pair' (x, y)) = seq (rnf x) (rnf y)

-- | Generates a left component and then a right component, each with the
-- element type's own 'arbitrary'.
instance (Arbitrary a) => Arbitrary (Pair a) where
    arbitrary =
        arbitrary >>= \x ->
        arbitrary >>= \y ->
        pure (Pair x y)

-- Numeric hierarchy instances
-- | Component-wise addition.
instance (AdditiveMagma a) => AdditiveMagma (Pair a) where
    plus (Pair' (x0, y0)) (Pair' (x1, y1)) = Pair' (plus x0 x1, plus y0 y1)

-- | The additive identity has 'zero' in both components.
instance (AdditiveUnital a) => AdditiveUnital (Pair a) where
    zero = Pair' (zero, zero)

-- The instances below declare no methods; anything they provide comes from
-- the class defaults.
-- NOTE(review): the contexts only require 'AdditiveMagma a', so 'Pair a' is
-- declared associative/commutative even when @a@ is not. Confirm whether
-- 'AdditiveAssociative a' / 'AdditiveCommutative a' were intended.
instance (AdditiveMagma a) => AdditiveAssociative (Pair a)
instance (AdditiveMagma a) => AdditiveCommutative (Pair a)
instance (AdditiveUnital a) => Additive (Pair a)

-- | Component-wise negation.
instance (AdditiveInvertible a) => AdditiveInvertible (Pair a) where
    negate (Pair' (x, y)) = Pair' (negate x, negate y)

-- | No methods are defined here; any group operations come from the class
-- defaults.
instance (AdditiveUnital a, AdditiveInvertible a) => AdditiveGroup (Pair a)

-- | Component-wise multiplication.
instance (MultiplicativeMagma a) => MultiplicativeMagma (Pair a) where
    times (Pair' (x0, y0)) (Pair' (x1, y1)) = Pair' (times x0 x1, times y0 y1)

-- | The multiplicative identity has 'one' in both components.
instance (MultiplicativeUnital a) => MultiplicativeUnital (Pair a) where
    one = Pair' (one, one)

-- The instances below declare no methods; anything they provide comes from
-- the class defaults.
-- NOTE(review): the contexts only require 'MultiplicativeMagma a', so
-- 'Pair a' is declared associative/commutative even when @a@ is not.
-- Confirm whether stronger constraints were intended.
instance (MultiplicativeMagma a) => MultiplicativeAssociative (Pair a)
instance (MultiplicativeMagma a) => MultiplicativeCommutative (Pair a)
instance (MultiplicativeUnital a) => Multiplicative (Pair a)

-- | Component-wise reciprocal.
instance (MultiplicativeInvertible a) => MultiplicativeInvertible (Pair a) where
    recip (Pair' (x, y)) = Pair' (recip x, recip y)

-- | No methods are defined here; any group operations come from the class
-- defaults.
instance (MultiplicativeUnital a, MultiplicativeInvertible a) => MultiplicativeGroup (Pair a)

-- | Component-wise integral division. 'divMod' runs on each component, and
-- the quotients and remainders are regrouped into two pairs
-- (quotients, remainders).
instance (Integral a) => Integral (Pair a) where
    divMod (Pair a0 b0) (Pair a1 b1) =
      let (qa, ra) = divMod a0 a1
          (qb, rb) = divMod b0 b1
      in (Pair qa qb, Pair ra rb)

-- metric instances
-- | Sign and absolute value, taken component-wise.
instance (Signed a) => Signed (Pair a) where
    sign = fmap sign
    abs = fmap abs

-- | Euclidean (L2) norm of the two components: @sqrt (a*a + b*b)@.
--
-- Squares are taken with '*' rather than @** (one + one)@. A generic '**'
-- may be implemented as @exp (log x * y)@, which yields NaN for a negative
-- component even though its square is well defined. 'distance' feeds
-- (possibly negative) differences through here, so this matters.
instance (ExpField a, AdditiveGroup a, MultiplicativeUnital a) => Normed (Pair a) a where
    size (Pair a b) = sqrt (a * a + b * b)

-- | A pair is near zero only when both components are. Two pairs are about
-- equal when their difference is near zero.
instance (Epsilon a) => Epsilon (Pair a) where
    nearZero (Pair' (x, y)) = nearZero x && nearZero y
    aboutEqual p q = nearZero (p - q)

-- | Distance between two pairs: the 'size' of their component-wise
-- difference (second argument minus first).
instance (ExpField a) => Metric (Pair a) a where
    distance (Pair' (x0, y0)) (Pair' (x1, y1)) =
        let dx = x1 - x0
            dy = y1 - y0
        in size (Pair dx dy)

-- | Ring-like instances. None of these declares any methods; anything they
-- provide comes from the class defaults, built on the component-wise
-- operations above.
instance (AdditiveGroup a, Distribution a) => Distribution (Pair a)
instance (AdditiveGroup a, Semiring a) => Semiring (Pair a)
instance (Ring a) => Ring (Pair a)
instance (CRing a) => CRing (Pair a)
instance (Field a) => Field (Pair a)
-- | Component-wise exponential and logarithm.
--
-- '**' and 'sqrt' are also delegated to the component type instead of being
-- left to the class defaults. NOTE(review): those defaults are presumably
-- @exp (log x * y)@ and @x ** (one / (one + one))@; confirm against numhask.
-- Delegating preserves the element type's own (possibly more accurate)
-- definitions and avoids NaN for negative bases under the exp/log form.
instance (ExpField a) => ExpField (Pair a) where
    exp (Pair a b) = Pair (exp a) (exp b)
    log (Pair a b) = Pair (log a) (log b)
    Pair a0 b0 ** Pair a1 b1 = Pair (a0 ** a1) (b0 ** b1)
    sqrt (Pair a b) = Pair (sqrt a) (sqrt b)
-- | A pair is NaN when either component is NaN. Any other methods come from
-- the class defaults.
instance (BoundedField a) => BoundedField (Pair a) where
    isNaN (Pair' (x, y)) = isNaN x || isNaN y