-- |
-- The base representation for bidirectional arrows (bijections).
{-# LANGUAGE Trustworthy, TypeOperators, KindSignatures, FlexibleInstances, CPP #-}
module Data.Invertible.Bijection
  ( Bijection(..)
  , type (<->)
  ) where

import Prelude hiding (id, (.))
import Control.Category
import Control.Arrow
#ifdef VERSION_semigroupoids
import Data.Semigroupoid (Semigroupoid(..))
import Data.Groupoid (Groupoid(..))
#endif
#ifdef VERSION_invariant
import Data.Functor.Invariant (Invariant(..), Invariant2(..))
#endif

infix 2 <->, :<->:

-- |A representation of a bidirectional arrow: a forward arrow paired with its inverse
-- (an embedding-projection pair, generalized over an arbitrary arrow type @a@).
-- Most uses will prefer the specialized '<->' type for function arrows.
--
-- To constitute a valid bijection, 'biTo' and 'biFrom' should be inverses:
--
--  * @biTo . biFrom = id@
--  * @biFrom . biTo = id@
--
-- These laws are not checked by the type; whoever builds a 'Bijection' is responsible for them.
--
-- It may be argued that the arguments should be in the opposite order due to the arrow syntax,
-- but it makes more sense to me to have the forward function come first.
data Bijection (a :: * -> * -> *) b c = (:<->:)
  { biTo :: a b c   -- ^ The forward arrow, from @b@ to @c@.
  , biFrom :: a c b -- ^ The inverse arrow, from @c@ back to @b@.
  }

-- |Specialization of 'Bijection' to function arrows.
-- Represents both a function, @f@, and its (presumed) inverse, @g@, written @f ':<->:' g@.
type (<->) = Bijection (->)

instance Category a => Category (Bijection a) where
  -- The identity bijection pairs the underlying identity arrow with itself.
  id = id :<->: id
  -- Forward arrows compose in the usual order; the inverses compose in the
  -- opposite order, since the inverse of (f . g) is (inverse g . inverse f).
  (f1 :<->: g1) . (f2 :<->: g2) = (f1 . f2) :<->: (g2 . g1)

-- |In order to use all the 'Arrow' functions, this instance is only partially lawful:
-- 'arr' cannot invert an arbitrary function, so the 'biFrom' it builds yields an
-- 'error' value for every input. See note on 'Control.Invertible.BiArrow.BiArrow''.
--
-- '&&&' is first-biased, and uses only the left argument's 'biFrom'.
instance Arrow a => Arrow (Bijection a) where
  -- No inverse is available for a plain function, so the inverse arrow
  -- produces an error thunk instead.
  arr f = arr f :<->: arr (const (error "Bijection: arr has no inverse"))
  first  (f :<->: g) = first f  :<->: first g
  second (f :<->: g) = second f :<->: second g
  (f :<->: g) *** (f' :<->: g') = (f *** f') :<->: (g *** g')
  -- The original input is recoverable from the first component alone, so the
  -- right argument's inverse is ignored (the alternative would be @g' . arr snd@).
  (f :<->: g) &&& (f' :<->: _) = (f &&& f') :<->: (g . arr fst)

-- |'|||' is left-biased, and uses only the left argument's 'biFrom'.
instance ArrowChoice a => ArrowChoice (Bijection a) where
  left  (f :<->: g) = left f  :<->: left g
  right (f :<->: g) = right f :<->: right g
  (f :<->: g) +++ (f' :<->: g') = (f +++ f') :<->: (g +++ g')
  -- Both branches share one output type, so the inverse cannot tell which
  -- branch produced a value; it always inverts through the left branch
  -- (the alternative would be @arr Right . g'@).
  (f :<->: g) ||| (f' :<->: _) = (f ||| f') :<->: (arr Left . g)

instance ArrowZero a => ArrowZero (Bijection a) where
  -- Both directions are the underlying zero arrow.
  zeroArrow = zeroArrow :<->: zeroArrow

#ifdef VERSION_semigroupoids
instance Semigroupoid a => Semigroupoid (Bijection a) where
  -- Forward arrows compose in order; the inverses compose in reverse order.
  (f1 :<->: g1) `o` (f2 :<->: g2) = (f1 `o` f2) :<->: (g2 `o` g1)

instance Semigroupoid a => Groupoid (Bijection a) where
  -- Inverting a bijection swaps its two directions.
  inv (f :<->: g) = g :<->: f
#endif

#ifdef VERSION_invariant
instance Invariant (Bijection (->) b) where
  -- Post-compose the forward function with @f@ and pre-compose the inverse
  -- with @g@; equivalent to composing with the bijection @f :<->: g@.
  invmap f g (t :<->: u) = (f . t) :<->: (u . g)

instance Invariant2 (Bijection (->)) where
  -- Wrap each direction: @g@/@h@ adapt the forward function's input/output,
  -- and @i@/@f@ adapt the inverse function's input/output. This equals
  -- @(h :<->: i) . bij . (g :<->: f)@ under the 'Category' instance.
  invmap2 f g h i (t :<->: u) = (h . t . g) :<->: (f . u . i)
#endif