{-# LANGUAGE CPP                        #-}
{-# LANGUAGE Safe                       #-}
{-# LANGUAGE PolyKinds                  #-}
{-# LANGUAGE ConstraintKinds            #-}
{-# LANGUAGE DefaultSignatures          #-}
{-# LANGUAGE DeriveFunctor              #-}
{-# LANGUAGE DeriveGeneric              #-}
{-# LANGUAGE FlexibleContexts           #-}
{-# LANGUAGE FlexibleInstances          #-}
{-# LANGUAGE NoImplicitPrelude          #-}
{-# LANGUAGE RebindableSyntax           #-}
{-# LANGUAGE TypeOperators              #-}
{-# LANGUAGE TypeFamilies               #-}
{-# LANGUAGE RankNTypes                 #-}

{-# LANGUAGE DataKinds                 #-}
module Data.Semimodule.Index where

import safe Control.Applicative
import safe Control.Category (Category, (>>>))
import safe Data.Algebra
import safe Data.Bool
import safe Data.Distributive
import safe Data.Foldable as Foldable (fold, foldl')
import safe Data.Functor.Rep
import safe Data.Functor.Compose
import safe Data.Functor.Product
import safe Data.Group
import safe Data.Magma
import safe Data.Profunctor
import safe Data.Semifield
import safe Data.Semigroup.Foldable as Foldable1
import safe Data.Semimodule
import safe Data.Semimodule.Basis
import safe Data.Semiring
import safe Prelude hiding (Num(..), Fractional(..), negate, sum, product)
import safe qualified Prelude as P

import qualified Data.Bifunctor as B
import qualified Control.Category as C
import qualified Control.Monad as M (join)
import Data.Tuple (swap)
import Data.Void

import safe Prelude (fromInteger, fromRational)


-- | Direct sum of vector spaces, expressed on their bases.
--
-- The basis of a direct sum is the disjoint union of the summands' bases,
-- so @a + b@ is simply @'Either' a b@. The synonym is declared without
-- parameters so that it can also be used partially applied.
type (+) = Either

-- | Collapse a sum onto its right summand, sending any left value
-- across with the supplied function.
rgt :: (a -> b) -> a + b -> b
rgt toRight s = case s of
  Left l  -> toRight l
  Right r -> r
{-# INLINE rgt #-}

-- | 'rgt' for a sum whose left summand is uninhabited.
rgt' :: Void + b -> b
rgt' = either absurd id
{-# INLINE rgt' #-}

-- | Collapse a sum onto its left summand, sending any right value
-- across with the supplied function.
lft :: (b -> a) -> a + b -> a
lft toLeft s = case s of
  Left l  -> l
  Right r -> toLeft r
{-# INLINE lft #-}

-- | 'lft' for a sum whose right summand is uninhabited.
lft' :: a + Void -> a
lft' = either id absurd
{-# INLINE lft' #-}

-- | Exchange the two summands of a direct sum.
eswap :: (a1 + a2) -> (a2 + a1)
eswap = either Right Left
{-# INLINE eswap #-}

-- | Duplicate a value into both slots of a pair.
fork :: a -> (a , a)
fork v = (v, v)
{-# INLINE fork #-}

-- | Merge a sum of two copies of the same type (the codiagonal).
join :: (a + a) -> a
join = either id id
{-# INLINE join #-}

-- | Apply the function in the second slot to the value in the first.
eval :: (a , a -> b) -> b
eval (arg, fun) = fun arg
{-# INLINE eval #-}

-- | Apply the function in the first slot to the value in the second.
apply :: (b -> a , b) -> a
apply (fun, arg) = fun arg
{-# INLINE apply #-}

---------------------------------------------------------------------

-- | A binary relation between two basis indices.
--
-- @'Index' b c@ relations correspond to permutations, projections,
-- and embedding transformations, as well as combinations thereof.
--
-- Because the scalar type /a/ is universally quantified, an 'Index' can
-- only copy, move, or drop scalars between coordinates; it cannot inspect
-- or combine them. Instantiating /a/ at /c/ and applying the result to
-- 'id' recovers a plain index map @b -> c@ (this is how 'first', 'left',
-- etc. use their arguments), and 'arr' goes the other way.
--
-- See also <https://en.wikipedia.org/wiki/Logical_matrix>.
--
type Index b c = forall a . Trans a b c

-- | An endomorphism over a free semimodule: a 'Trans' whose source and
-- target bases are both /b/, with scalars of type /a/.
--
type Endo a b = Trans a b b

-- | A general linear transformation between free semimodules indexed with bases /b/ and /c/.
--
-- A @'Trans' a b c@ wraps a function that turns a vector given by its
-- coordinate function @c -> a@ into one given by @b -> a@; that is, it
-- maps /c/-indexed vectors to /b/-indexed vectors (compare 'app').
--
-- NOTE(review): linearity is not enforced by the type; any function of
-- this shape can be wrapped. The derived 'Functor' instance maps over the
-- last parameter /c/.
newtype Trans a b c = Trans { runTrans :: (c -> a) -> (b -> a) } deriving Functor

-- | Identity leaves the coordinate function untouched. In @f . g@ the left
-- operand's function is applied to the coordinate function first, and the
-- right operand's function is applied to that result.
instance Category (Trans a) where
  id = Trans (\k -> k)
  Trans f . Trans g = Trans (\k -> g (f k))

-- | 'lmap' reindexes the produced /b/-side coordinate function through the
-- given map. 'rmap' is the derived 'fmap', which acts on the /c/ side.
instance Profunctor (Trans a) where
  lmap f (Trans t) = Trans (\k i -> t k (f i))
  rmap = fmap

-- arr toE3 :: Dim3 e => Index e E3

-- | Lift a plain index map into an 'Index'.
--
-- Coordinate @i@ of the result is coordinate @f i@ of the input, i.e. the
-- coordinate function is reindexed by precomposition. Equivalently,
-- @'arr' f = 'rmap' f 'C.id'@.
arr :: (b -> c) -> Index b c
arr f = Trans (\k -> k . f)

-- | Run a transformation on vectors stored in containers.
--
-- The input container is viewed as a coordinate function via 'index',
-- transformed with 'runTrans', and rebuilt with 'tabulate'.
-- NOTE(review): assumes 'Basis' ties /b/ and /c/ to the representations of
-- /f/ and /g/ — confirm against "Data.Semimodule.Basis".
app :: Basis b f => Basis c g => Trans a b c -> g a -> f a
app t v = tabulate (runTrans t (index v))

---------------------------------------------------------------------

-- | Coordinate @(p , q)@ of the result is coordinate @q@ of the input, so
-- a /b/-indexed vector is spread over @(a , b)@, ignoring the /a/ part.
--
-- NOTE(review): 'in1' selects the second component and 'in2' the first;
-- confirm this numbering is intended.
in1 :: Index (a , b) b
in1 = arr (\(_, q) -> q)
{-# INLINE in1 #-}

-- | Coordinate @(p , q)@ of the result is coordinate @p@ of the input, so
-- an /a/-indexed vector is spread over @(a , b)@, ignoring the /b/ part.
in2 :: Index (a , b) a
in2 = arr (\(p, _) -> p)
{-# INLINE in2 #-}

-- | Restrict an @(a + b)@-indexed vector to its left summand: coordinate
-- @p@ of the result is coordinate @'Left' p@ of the input.
exl :: Index a (a + b)
exl = Trans (\k -> k . Left)
{-# INLINE exl #-}

-- | Restrict an @(a + b)@-indexed vector to its right summand: coordinate
-- @q@ of the result is coordinate @'Right' q@ of the input.
exr :: Index b (a + b)
exr = Trans (\k -> k . Right)
{-# INLINE exr #-}

-- | Transpose the factors of a product basis: coordinate @(p , q)@ of the
-- result is coordinate @(q , p)@ of the input.
braid :: Index (a , b) (b , a)
braid = arr (\(p, q) -> (q, p))
{-# INLINE braid #-}

-- | Exchange the summands of a sum basis: @'Left' p@ reads @'Right' p@
-- of the input and vice versa.
ebraid :: Index (a + b) (b + a)
ebraid = arr (either Right Left)
{-# INLINE ebraid #-}

-- | Act on the first factor of a product basis, leaving the second alone:
-- coordinate @(p , e)@ of the result is coordinate @(t p , e)@ of the
-- input, where @t@ is the argument's underlying index map.
first :: Index b c -> Index (b , d) (c , d)
-- The underlying map @b -> c@ is recovered by instantiating the scalar
-- type at /c/ and feeding 'id' to the wrapped function.
first t = arr (B.first (runTrans t id))
{-# INLINE first #-}

-- | Act on the second factor of a product basis, leaving the first alone:
-- coordinate @(e , p)@ of the result is coordinate @(e , t p)@ of the
-- input, where @t@ is the argument's underlying index map.
second :: Index b c -> Index (d , b) (d , c)
second t = arr (B.second (runTrans t id))
{-# INLINE second #-}

-- | Act on the left summand of a sum basis, leaving the right alone:
-- @'Left' p@ reads @'Left' (t p)@ of the input and @'Right' e@ reads
-- @'Right' e@, where @t@ is the argument's underlying index map.
left :: Index b c -> Index (b + d) (c + d)
-- As in 'first', the map @b -> c@ comes from instantiating the scalar
-- type at /c/ and applying the wrapped function to 'id'.
left t = arr (B.first (runTrans t id))
{-# INLINE left #-}

-- | Act on the right summand of a sum basis, leaving the left alone:
-- @'Right' p@ reads @'Right' (t p)@ of the input and @'Left' e@ reads
-- @'Left' e@, where @t@ is the argument's underlying index map.
right :: Index b c -> Index (d + b) (d + c)
right t = arr (B.second (runTrans t id))
{-# INLINE right #-}

infixr 3 ***

-- | Act on both factors of a product basis at once: coordinate @(p , q)@
-- of the result is coordinate @(x p , y q)@ of the input, where @x@ and
-- @y@ stand for the arguments' underlying index maps.
--
-- Built with a single (lazy) 'B.bimap'. This yields the same index map as
-- chaining 'first' and swaps.
(***) :: Index a1 b1 -> Index a2 b2 -> Index (a1 , a2) (b1 , b2)
x *** y = arr (B.bimap (runTrans x id) (runTrans y id))
{-# INLINE (***) #-}

infixr 2 +++

-- | Act on both summands of a sum basis: @'Left' p@ reads @'Left' (x p)@
-- and @'Right' q@ reads @'Right' (y q)@ of the input, where @x@ and @y@
-- stand for the arguments' underlying index maps.
(+++) :: Index a1 b1 -> Index a2 b2 -> Index (a1 + a2) (b1 + b2)
x +++ y = arr (B.bimap (runTrans x id) (runTrans y id))
{-# INLINE (+++) #-}

infixr 3 &&&

-- | Fan out into a product basis: coordinate @i@ of the result is
-- coordinate @(x i , y i)@ of the input, where @x@ and @y@ stand for the
-- arguments' underlying index maps.
(&&&) :: Index a b1 -> Index a b2 -> Index a (b1 , b2)
x &&& y = arr (\i -> (runTrans x id i, runTrans y id i))
{-# INLINE (&&&) #-}

infixr 2 |||

-- | Fan in from a sum basis: @'Left' p@ reads coordinate @x p@ and
-- @'Right' q@ reads coordinate @y q@ of the input, where @x@ and @y@ stand
-- for the arguments' underlying index maps.
(|||) :: Index a1 b -> Index a2 b -> Index (a1 + a2) b
x ||| y = arr (either (runTrans x id) (runTrans y id))
{-# INLINE (|||) #-}

infixr 0 $$$

-- | Pointwise application of index maps: coordinate @i@ of the result is
-- coordinate @(f i) (x i)@ of the input, where @f@ and @x@ stand for the
-- arguments' underlying index maps.
($$$) :: Index a (b -> c) -> Index a b -> Index a c
fs $$$ xs = arr (\i -> runTrans fs id i (runTrans xs id i))
{-# INLINE ($$$) #-}

-- | Split the source index and follow the first relation: coordinate @i@
-- of the result is coordinate @x (fst (f i))@ of the input.
--
-- NOTE(review): the second relation is never consulted (the product's
-- second component is discarded by 'fst'); confirm this is intended.
adivide :: (a -> (a1 , a2)) -> Index a1 b -> Index a2 b -> Index a b
adivide f x _ = arr (runTrans x id . fst . f)
{-# INLINE adivide #-}

-- | 'adivide' with the source index duplicated into both slots.
adivide' :: Index a b -> Index a b -> Index a b
adivide' = adivide (\i -> (i, i))
{-# INLINE adivide' #-}

-- | 'adivide' on an already-paired source index.
adivided :: Index a1 b -> Index a2 b -> Index (a1 , a2) b
adivided = adivide (\p -> p)
{-# INLINE adivided #-}

-- | Follow the first relation and post-process through a sum: coordinate
-- @i@ of the result is coordinate @f ('Left' (x i))@ of the input.
--
-- NOTE(review): the source index is always injected with 'Left', so the
-- second relation is never reached; confirm this is intended.
aselect :: ((b1 + b2) -> b) -> Index a b1 -> Index a b2 -> Index a b
aselect f x _ = arr (f . Left . runTrans x id)
{-# INLINE aselect #-}

-- | 'aselect' merging identical summands.
aselect' :: Index a b -> Index a b -> Index a b
aselect' = aselect (either id id)
{-# INLINE aselect' #-}

-- | 'aselect' keeping the sum-typed target index as is.
aselected :: Index a b1 -> Index a b2 -> Index a (b1 + b2)
aselected = aselect (\s -> s)
{-# INLINE aselected #-}