{-# LANGUAGE CPP #-}
{-# LANGUAGE DeriveDataTypeable #-}
{-# LANGUAGE DeriveFunctor #-}
{-# LANGUAGE DeriveFoldable #-}
{-# LANGUAGE DeriveTraversable #-}
{-# LANGUAGE TypeFamilies #-}
{-# OPTIONS_GHC -Wall #-}
#if MIN_VERSION_base(4,9,0)
#define LIFTED_FUNCTOR_CLASSES 1
#else
#if MIN_VERSION_transformers(0,5,0)
#define LIFTED_FUNCTOR_CLASSES 1
#else
#if MIN_VERSION_transformers_compat(0,5,0) && !MIN_VERSION_transformers(0,4,0)
#define LIFTED_FUNCTOR_CLASSES 1
#endif
#endif
#endif
module Data.UniformPair
  ( Pair(..), fstP,sndP, firstP, secondP, getP, onElemP, swapP, compareSwap
  ) where
import Data.Data (Data, Typeable)
import Data.Monoid (Monoid(..))
import Data.Semigroup (Semigroup (..))
import Data.Functor ((<$>))
import Data.Foldable (Foldable(..))
import Data.Traversable (Traversable(..))
import Data.Functor.Classes (Eq1(..), Ord1(..), Show1(..))
import Control.Applicative (Applicative(..)) 
import Control.DeepSeq (NFData(..))
import Data.Distributive (Distributive(..))
import Data.Functor.Rep (Representable(..),distributeRep)
import qualified Prelude.Extras as PE (Eq1, Ord1, Show1)
-- | A pair whose two components share the same type.
--
-- The constructor is declared infix (non-associative, precedence 1), so
-- the derived 'Show' instance renders values as @x :# y@.
data Pair a = a :# a
  deriving ( Data, Eq, Ord, Show
           , Functor, Foldable, Traversable, Typeable
           )
infix 1 :#
-- | Fully evaluates a pair by normalising the first element, then the second.
instance NFData a => NFData (Pair a) where
    rnf (x :# y) = seq (rnf x) (rnf y)
-- Instances for the lifted classes of Prelude.Extras. No method bindings
-- are given (only the class names are imported), so every method comes
-- from the class default definitions -- presumably built on the derived
-- Eq/Ord/Show instances of Pair; TODO confirm against the prelude-extras
-- version in use.
instance PE.Eq1 Pair
instance PE.Ord1 Pair
instance PE.Show1 Pair
#if LIFTED_FUNCTOR_CLASSES
-- Lifted-class API: each method receives the element-level function.
instance Eq1 Pair where
  liftEq f (x1 :# y1) (x2 :# y2) = f x1 x2 && f y1 y2
instance Ord1 Pair where
  -- Lexicographic: the second comparison only decides a tie on the first.
  liftCompare f (x1 :# y1) (x2 :# y2) = mappend (f x1 x2) (f y1 y2)
instance Show1 Pair where
  -- Same layout as the derived Show: infix 1, operands at precedence 2.
  liftShowsPrec showElem _ p (x :# y) =
    showParen (p > 1) (showElem 2 x . showString " :# " . showElem 2 y)
#else
-- Older API: methods carry element constraints, so they reuse the
-- derived instances of Pair directly.
instance Eq1 Pair where
  eq1 = (==)
instance Ord1 Pair where
  compare1 = compare
instance Show1 Pair where
  showsPrec1 = showsPrec
#endif
-- | The first element of a pair.
fstP :: Pair a -> a
fstP p = case p of
  x :# _ -> x
-- | The second element of a pair.
sndP :: Pair a -> a
sndP p = case p of
  _ :# y -> y
-- | Apply a function to one element of a pair, leaving the other alone.
--
-- The argument pair is destructured lazily (a @where@ pattern binding),
-- so the resulting ':#' constructor is produced without evaluating the
-- input pair first.
firstP, secondP :: (a -> a) -> (Pair a -> Pair a)
firstP f p = f x :# y
  where x :# y = p
secondP g p = x :# g y
  where x :# y = p
-- | Combines two pairs element by element.
instance Semigroup a => Semigroup (Pair a) where
  (<>) (x1 :# y1) (x2 :# y2) = (<>) x1 x2 :# (<>) y1 y2
-- | Element-wise monoid: the identity pairs two identities, and 'mappend'
-- combines corresponding elements.
instance Monoid a => Monoid (Pair a) where
  mempty = mempty :# mempty
  mappend (x1 :# y1) (x2 :# y2) = mappend x1 x2 :# mappend y1 y2
-- | Zip-like applicative: 'pure' duplicates its argument, and '<*>' applies
-- each function to the element in the same position.
instance Applicative Pair where
  pure x = x :# x
  (<*>) (f :# g) (x :# y) = f x :# g y
-- | Bind maps every element to a pair and keeps the diagonal of the
-- resulting pair of pairs (see 'joinP').
instance Monad Pair where
  return = pure
  (>>=) m f = joinP (fmap f m)
-- | Flatten a pair of pairs by keeping its diagonal: the first element of
-- the first inner pair and the second element of the second inner pair.
-- The outer pair and both inner pairs are forced, left to right.
joinP :: Pair (Pair a) -> Pair a
joinP pp = case pp of
  (x :# _) :# (_ :# y) -> x :# y
-- | Pair is distributive because it is Representable (indexed by Bool);
-- the distribution is obtained from that representation via
-- 'distributeRep'.
instance Distributive Pair where
  distribute = distributeRep
-- | A pair is a function from 'Bool': 'False' indexes the first element
-- and 'True' the second.
instance Representable Pair where
  type Rep Pair = Bool
  tabulate g = g False :# g True
  index (first :# second) b = if b then second else first
-- | Update one element of a pair, selected by a 'Bool' index.
--
-- The indexing agrees with 'getP', 'index' and 'tabulate': 'False'
-- selects the first element and 'True' the second. (The previous
-- definition had this inverted, so that @onElemP True@ modified the
-- element that @getP True@ does not read.)
--
-- The pair is destructured lazily, as in 'firstP' and 'secondP', so the
-- result constructor is produced without forcing the input pair.
onElemP :: Bool -> (a -> a) -> Pair a -> Pair a
onElemP False = firstP
onElemP True  = secondP
-- | Extract one element of a pair: 'False' gives the first, 'True' the
-- second.
getP :: Bool -> Pair a -> a
getP b = if b then sndP else fstP
-- | Exchange the two elements of a pair.
swapP :: Pair a -> Pair a
swapP p = case p of
  x :# y -> y :# x
-- | Order a pair: the pair is kept as-is when @x <= y@ holds and swapped
-- otherwise (including when '<=' is False for unordered values).
compareSwap :: Ord a => Pair a -> Pair a
compareSwap p@(x :# y) = if x <= y then p else y :# x