{-# LANGUAGE CPP #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE NoImplicitPrelude #-}
{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE UndecidableInstances #-}
{-# OPTIONS_GHC -Wall #-}
-- COMPLETE pragmas arrived in GHC 8.2 (version macro value 802). Older
-- compilers neither recognise the COMPLETE pragma used for the Pair pattern
-- nor know that matching on it is exhaustive, so both warnings are silenced
-- for them only. The bound must be 802: every 8.x release is below 820.
#if ( __GLASGOW_HASKELL__ < 802 )
{-# OPTIONS_GHC -fno-warn-incomplete-patterns #-}
{-# OPTIONS_GHC -fno-warn-unrecognised-pragmas #-}
#endif

-- | A homogeneous pair type. V2 from the linear package would serve, but it pulls in a lens dependency, and there is no canonical alternative.
module NumHask.Pair
  ( Pair(..)
  , pattern Pair
  ) where

import NumHask.Prelude
import Text.Show

import Data.Distributive
import Data.Functor.Apply (Apply(..))
import Data.Functor.Classes
import Data.Semigroup.Foldable (Foldable1(..))
import Data.Semigroup.Traversable (Traversable1(..))
import Test.QuickCheck.Arbitrary (Arbitrary(..))

-- $setup
-- >>> :set -XNoImplicitPrelude
--

-- | A pair of @a@s, implemented as a tuple but exposed through the 'Pair' pattern synonym.
--
-- >>> fmap (+1) (Pair 1 2)
-- Pair 2 3
-- >>> pure one :: Pair Int
-- Pair 1 1
-- >>> (*) <$> Pair 1 2 <*> pure 2
-- Pair 2 4
-- >>> foldr (++) [] (Pair [1,2] [3])
-- [1,2,3]
-- >>> Pair "a" "pair" `mappend` pure " " `mappend` Pair "string" "mappended"
-- Pair "a string" "pair mappended"
--
-- As a Ring and Field:
--
-- >>> Pair 0 1 + zero
-- Pair 0 1
-- >>> Pair 0 1 + Pair 2 3
-- Pair 2 4
-- >>> Pair 1 1 - one
-- Pair 0 0
-- >>> Pair 0 1 * one
-- Pair 0 1
-- >>> Pair 0 1 / one
-- Pair 0.0 1.0
-- >>> Pair 11 12 `mod` (pure 6)
-- Pair 5 0
--
-- As a numhask module, with scalar operators such as '.+':
--
-- >>> Pair 1 2 .+ 3
-- Pair 4 5
--
-- As a Distributive / Representable functor, indexed by 'Bool':
--
-- >>>  distribute [Pair 1 2, Pair 3 4]
-- Pair [1,3] [2,4]
-- >>> index (Pair 'l' 'r') False
-- 'l'
newtype Pair a =
  Pair' (a, a) -- raw constructor over a plain tuple; prefer the Pair pattern
  deriving (Eq, Generic)

-- | The preferred way to build and match a 'Pair': a bidirectional pattern
-- synonym over the underlying tuple, so @Pair 1 2@ is @Pair' (1, 2)@.
pattern Pair :: a -> a -> Pair a
pattern Pair a b = Pair' (a,b)
-- Declares a match on 'Pair' alone to be exhaustive, so the pattern-match
-- checker does not report such matches as incomplete (GHC 8.2 and later).
{-# COMPLETE Pair#-}

-- | Renders as @Pair x y@. Each component is shown at application
-- precedence, and the whole expression is parenthesised when it appears as
-- an argument. The output is therefore valid Haskell, e.g.
-- @Pair (Just 1) Nothing@ or @Pair (Pair 1 2) (Pair 3 4)@ rather than
-- @Pair Just 1 Nothing@. The rendering matches the 'Show1' instance, which
-- also goes through 'showsBinaryWith'. Simple components such as @Pair 2 3@
-- render exactly as before.
instance (Show a) => Show (Pair a) where
  showsPrec d (Pair a b) = showsBinaryWith showsPrec showsPrec "Pair" d a b

-- | Applies the function to both components independently.
instance Functor Pair where
  fmap f (Pair' (x, y)) = Pair' (f x, f y)

-- | Two pairs are equal under the supplied element equality when both their
-- left and their right components are.
instance Eq1 Pair where
  liftEq eq (Pair' (x0, y0)) (Pair' (x1, y1)) = eq x0 x1 && eq y0 y1

-- | Shows a pair via 'showsBinaryWith': the constructor name "Pair" followed
-- by both components, each rendered with the supplied element printer. The
-- list printer is unused.
instance Show1 Pair where
  liftShowsPrec showItem _ prec (Pair a b) =
    showsBinaryWith showItem showItem "Pair" prec a b

-- | Componentwise application: the left function is applied to the left
-- value and the right function to the right value.
instance Apply Pair where
  (<.>) (Pair' (f, g)) (Pair' (x, y)) = Pair' (f x, g y)

-- | 'pure' duplicates a value into both slots; '<*>' applies slot by slot.
instance Applicative Pair where
  pure x = Pair' (x, x)
  Pair' (f, g) <*> Pair' (x, y) = Pair' (f x, g y)

-- | Diagonal bind: the left result is the left component of @k a@ and the
-- right result is the right component of @k b@.
instance Monad Pair where
  Pair a b >>= k = Pair (leftOf (k a)) (rightOf (k b))
    where
      leftOf (Pair l _) = l
      rightOf (Pair _ r) = r

-- | Folds the left component before the right one.
instance Foldable Pair where
  foldMap f (Pair' (x, y)) = mappend (f x) (f y)

-- No methods are defined here, so every Foldable1 method comes from the
-- class defaults of Data.Semigroup.Foldable.
-- NOTE(review): an explicit foldMap1 is left commented out below; the reason
-- is not visible here, so confirm it type-checks before enabling it.
instance Foldable1 Pair
    -- foldMap1 f (Pair a b) = f a <> f b

-- | Runs the effect for the left component, then the right one, and
-- rebuilds the pair from the results.
instance Traversable Pair where
  traverse f (Pair' (x, y)) = fmap Pair (f x) <*> f y

-- | Like 'traverse', but only needs an 'Apply' (no 'pure'), since a pair is
-- never empty. The operator is qualified because the prelude in use may
-- export another '<.>'.
instance Traversable1 Pair where
  traverse1 f (Pair' (x, y)) = fmap Pair (f x) Data.Functor.Apply.<.> f y

-- | Componentwise monoid: 'mempty' in both slots, with 'mappend' applied
-- slot by slot.
instance (Monoid a) => Monoid (Pair a) where
  mempty = pure mempty
  mappend x y = mappend <$> x <*> y

-- | Turns a functor of pairs into a pair of functors. The left slot collects
-- every left component and the right slot every right component.
instance Distributive Pair where
  collect f xs = Pair' (fmap (firstOf . f) xs, fmap (secondOf . f) xs)
    where
      firstOf (Pair' (l, _)) = l
      secondOf (Pair' (_, r)) = r

-- | A 'Pair' is indexed by 'Bool': 'False' selects the left component and
-- 'True' selects the right one.
instance Representable Pair where
  type Rep Pair = Bool
  tabulate g = Pair' (g False, g True)
  index (Pair' (l, r)) which = if which then r else l

-- | Fully evaluates the left component, then the right one.
instance NFData a => NFData (Pair a) where
  rnf (Pair' (x, y)) = seq (rnf x) (rnf y)

-- | Generates the left component and then the right one, each with the
-- component type's own generator.
--
-- 'shrink' tries shrinking one component at a time while keeping the other
-- fixed. This lets QuickCheck minimise failing 'Pair' counterexamples; the
-- class default 'shrink' yields no candidates at all.
instance (Arbitrary a) => Arbitrary (Pair a) where
  arbitrary = Pair <$> arbitrary <*> arbitrary
  shrink (Pair a b) =
    [Pair a' b | a' <- shrink a] <> [Pair a b' | b' <- shrink b]

-- numeric hierarchy
-- | Componentwise addition.
instance (AdditiveMagma a) => AdditiveMagma (Pair a) where
  plus p q = plus <$> p <*> q

-- | The additive identity is the component identity in both slots.
instance (AdditiveUnital a) => AdditiveUnital (Pair a) where
  zero = pure zero

-- These instances define no methods; addition on 'Pair' is the componentwise
-- 'plus' of the AdditiveMagma instance.
-- NOTE(review): the associativity and commutativity instances are
-- constrained only on AdditiveMagma a, so the laws they assert for 'Pair a'
-- hold only if the component type's 'plus' already satisfies them.
instance (AdditiveMagma a) => AdditiveAssociative (Pair a)

instance (AdditiveMagma a) => AdditiveCommutative (Pair a)

instance (AdditiveUnital a) => Additive (Pair a)

-- | Negates each component.
instance (AdditiveInvertible a) => AdditiveInvertible (Pair a) where
  negate = fmap negate

-- No methods are defined here. Group operations on 'Pair' (such as the @-@
-- used in the module doctests) come from the class itself, built on the
-- componentwise 'plus', 'zero' and 'negate' instances.
-- NOTE(review): confirm against the AdditiveGroup class definition.
instance (AdditiveUnital a, AdditiveInvertible a) =>
         AdditiveGroup (Pair a)

-- | Componentwise (Hadamard) multiplication.
instance (MultiplicativeMagma a) => MultiplicativeMagma (Pair a) where
  times p q = times <$> p <*> q

-- | The multiplicative identity is the component identity in both slots.
instance (MultiplicativeUnital a) => MultiplicativeUnital (Pair a) where
  one = pure one

-- These instances define no methods; multiplication on 'Pair' is the
-- componentwise 'times' of the MultiplicativeMagma instance.
-- NOTE(review): the associativity and commutativity instances are
-- constrained only on MultiplicativeMagma a, so the laws they assert for
-- 'Pair a' hold only if the component type's 'times' already satisfies them.
instance (MultiplicativeMagma a) => MultiplicativeAssociative (Pair a)

instance (MultiplicativeMagma a) => MultiplicativeCommutative (Pair a)

instance (MultiplicativeUnital a) => Multiplicative (Pair a)

-- | Takes the reciprocal of each component.
instance (MultiplicativeInvertible a) => MultiplicativeInvertible (Pair a) where
  recip = fmap recip

-- No methods are defined here. Division on 'Pair' (as in the module doctest
-- @Pair 0 1 / one@) comes from the class itself, built on the componentwise
-- 'times', 'one' and 'recip' instances.
-- NOTE(review): confirm against the MultiplicativeGroup class definition.
instance (MultiplicativeUnital a, MultiplicativeInvertible a) =>
         MultiplicativeGroup (Pair a)

-- | Componentwise 'divMod': the quotients form the first result pair and the
-- remainders the second.
instance (Integral a) => Integral (Pair a) where
  divMod (Pair a0 b0) (Pair a1 b1) =
    let (qa, ra) = divMod a0 a1
        (qb, rb) = divMod b0 b1
    in (Pair qa qb, Pair ra rb)

-- | 'sign' and 'abs' are taken independently for each component.
instance (Signed a) => Signed (Pair a) where
  sign = fmap sign
  abs = fmap abs

-- | Euclidean (L2) norm: the square root of the sum of the squared
-- components.
--
-- The squares are computed with 'times' rather than @** (one + one)@. That
-- avoids depending on how the component type implements '**': a power
-- defined as @exp (log x * y)@ is NaN for negative @x@. The 'Metric'
-- instance passes componentwise differences here, which are often negative.
instance (ExpField a, AdditiveGroup a, MultiplicativeUnital a) =>
         Normed (Pair a) a where
  size (Pair a b) = sqrt (times a a + times b b)

-- | L1-based Ord instance: pairs are ordered first by the sum of the
-- absolute values of their components.
--
-- Pairs with equal L1 norm are ordered lexicographically by their
-- components. Without this tie-break, @Pair 1 0 <= Pair 0 1@ and
-- @Pair 0 1 <= Pair 1 0@ would both hold even though the pairs are unequal
-- under the derived 'Eq'. The default 'compare' would then return 'LT' in
-- both directions, breaking sorting, 'max'/'min', and ordered containers.
instance (Eq a, Ord a, Signed a, Additive a) => Ord (Pair a) where
  (<=) (Pair x y) (Pair x' y')
    | l1 < l1' = True
    | l1 == l1' = (x, y) <= (x', y')
    | otherwise = False
    where
      l1 = abs x + abs y
      l1' = abs x' + abs y'

-- | A pair is near zero only when both components are. Two pairs are about
-- equal when their difference is near zero.
instance (Epsilon a) => Epsilon (Pair a) where
  nearZero (Pair' (x, y)) = nearZero x && nearZero y
  aboutEqual p q = nearZero (p - q)

-- | Distance is the 'size' of the componentwise difference (second argument
-- minus first), using the component type's own subtraction.
instance (ExpField a) => Metric (Pair a) a where
  distance p q = size ((-) <$> q <*> p)

-- The remaining algebraic instances for 'Pair' define no methods of their
-- own. Each one relies on the componentwise instances elsewhere in this
-- module and on the matching constraint on the component type.
instance (AdditiveGroup a, Distribution a) => Distribution (Pair a)

instance (Ring a) => Ring (Pair a)

instance (AdditiveGroup a, Semiring a) => Semiring (Pair a)

instance (CRing a) => CRing (Pair a)

instance (Field a) => Field (Pair a)

-- | 'exp' and 'log' are applied to each component. No other methods are
-- defined here.
instance (ExpField a) => ExpField (Pair a) where
  exp = fmap exp
  log = fmap log

-- | A pair is NaN when either component is; the left one is checked first.
instance (BoundedField a) => BoundedField (Pair a) where
  isNaN (Pair' (x, y)) = isNaN x || isNaN y