{-# LANGUAGE CPP #-}
{-# LANGUAGE Safe #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE DefaultSignatures #-}
{-# LANGUAGE DeriveFunctor #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE NoImplicitPrelude #-}
{-# LANGUAGE RebindableSyntax #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE DataKinds #-}
module Data.Semimodule.Index where
import safe Control.Applicative
import safe Control.Category (Category, (>>>))
import safe Data.Algebra
import safe Data.Bool
import safe Data.Distributive
import safe Data.Foldable as Foldable (fold, foldl')
import safe Data.Functor.Rep
import safe Data.Functor.Compose
import safe Data.Functor.Product
import safe Data.Group
import safe Data.Magma
import safe Data.Profunctor
import safe Data.Semifield
import safe Data.Semigroup.Foldable as Foldable1
import safe Data.Semimodule
import safe Data.Semimodule.Basis
import safe Data.Semiring
import safe Prelude hiding (Num(..), Fractional(..), negate, sum, product)
import safe qualified Prelude as P
import qualified Data.Bifunctor as B
import qualified Control.Category as C
import qualified Control.Monad as M (join)
import Data.Tuple (swap)
import Data.Void
import safe Prelude (fromInteger, fromRational)
-- | Infix type-level synonym for 'Either': @a + b@ means @Either a b@.
-- Because it is declared without parameters, @(+)@ may also appear
-- unsaturated (e.g. @(+) a@).
type (+) = Either
-- | Collapse a sum onto its right-hand type.
--
-- A 'Left' value is converted with the supplied function; a 'Right' value
-- is returned unchanged.
rgt :: (a -> b) -> a + b -> b
rgt f (Left x) = f x
rgt _ (Right y) = y
{-# INLINE rgt #-}
-- | Unwrap a sum whose left summand is uninhabited ('Void').
--
-- Every total value is therefore a 'Right'.
rgt' :: Void + b -> b
rgt' = either absurd id
{-# INLINE rgt' #-}
-- | Collapse a sum onto its left-hand type.
--
-- A 'Left' value is returned unchanged; a 'Right' value is converted with
-- the supplied function.
lft :: (b -> a) -> a + b -> a
lft _ (Left x) = x
lft f (Right y) = f y
{-# INLINE lft #-}
-- | Unwrap a sum whose right summand is uninhabited ('Void').
--
-- Every total value is therefore a 'Left'.
lft' :: a + Void -> a
lft' = either id absurd
{-# INLINE lft' #-}
-- | Exchange the two summands of a sum: 'Left' becomes 'Right' and vice versa.
eswap :: (a1 + a2) -> (a2 + a1)
eswap = either Right Left
{-# INLINE eswap #-}
-- | Duplicate a value into both components of a pair (the diagonal map).
fork :: a -> (a , a)
fork x = (x, x)
{-# INLINE fork #-}
-- | Merge a sum of two identical types (the codiagonal map).
--
-- The information about which side the value came from is discarded.
join :: (a + a) -> a
join (Left x) = x
join (Right x) = x
{-# INLINE join #-}
-- | Apply the function stored in the second component to the value stored
-- in the first.
eval :: (a , a -> b) -> b
eval (x, k) = k x
{-# INLINE eval #-}
-- | Apply the function stored in the first component to the value stored
-- in the second.
apply :: (b -> a , b) -> a
apply (k, x) = k x
{-# INLINE apply #-}
-- | A 'Trans' from @b@ to @c@ that is polymorphic in the answer type @a@.
--
-- This is a rank-2 synonym: a function taking an @Index@ argument receives a
-- value usable at every answer type. The code in this module recovers the
-- underlying @b -> c@ by running it on 'id' (as in 'first').
type Index b c = forall a . Trans a b c
-- | A 'Trans' whose source and target index types coincide.
type Endo a b = Trans a b b
-- | An index transformation in continuation-passing form.
--
-- Given a consumer of @c@-indices (@c -> a@), it produces a consumer of
-- @b@-indices (@b -> a@). Note the direction: a @Trans a b c@ turns
-- @c -> a@ into @b -> a@. The derived 'Functor' maps over @c@.
newtype Trans a b c = Trans { runTrans :: (c -> a) -> (b -> a) } deriving Functor
-- | 'Trans' composes contravariantly in the continuation.
--
-- To run @outer . inner@, the incoming continuation is first transformed by
-- @outer@, and the result is then transformed by @inner@.
instance Category (Trans a) where
  id = Trans (\k -> k)
  Trans outer . Trans inner = Trans (\k -> inner (outer k))
-- | 'lmap' pre-composes onto the produced @b -> a@ consumer.
--
-- 'rmap' pre-composes onto the incoming @c -> a@ continuation, which is
-- exactly what the derived 'fmap' does.
instance Profunctor (Trans a) where
  lmap pre (Trans run) = Trans (\k x -> run k (pre x))
  rmap post (Trans run) = Trans (\k -> run (k . post))
-- | Lift a plain index function into an 'Index'.
--
-- The function is applied to each index before it reaches the continuation.
arr :: (b -> c) -> Index b c
arr f = Trans (\k x -> k (f x))
-- | Reindex a container through a 'Trans'.
--
-- The input @v@ is viewed as a function of its indices via 'index'. That
-- function is pushed through the transformation, and the resulting
-- @b -> a@ is rebuilt into an @f a@ with 'tabulate'.
app :: Basis b f => Basis c g => Trans a b c -> g a -> f a
app t v = tabulate (runTrans t (index v))
-- | Index map sending a pair to its second component.
in1 :: Index (a , b) b
in1 = arr (\(_, y) -> y)
{-# INLINE in1 #-}
-- | Index map sending a pair to its first component.
in2 :: Index (a , b) a
in2 = arr (\(x, _) -> x)
{-# INLINE in2 #-}
-- | Index map tagging an index as the left summand.
exl :: Index a (a + b)
exl = Trans (\k -> k . Left)
{-# INLINE exl #-}
-- | Index map tagging an index as the right summand.
exr :: Index b (a + b)
exr = Trans (\k -> k . Right)
{-# INLINE exr #-}
-- | Index map exchanging the components of a pair.
braid :: Index (a , b) (b , a)
braid = arr (\(x, y) -> (y, x))
{-# INLINE braid #-}
-- | Index map exchanging the summands of a sum.
ebraid :: Index (a + b) (b + a)
ebraid = arr (either Right Left)
{-# INLINE ebraid #-}
-- | Apply an 'Index' to the first component of a pair, leaving the second
-- component untouched.
first :: Index b c -> Index (b , d) (c , d)
-- Running @t@ on 'id' recovers its underlying @b -> c@ function.
first t = arr (B.first (runTrans t id))
-- | Apply an 'Index' to the second component of a pair, leaving the first
-- component untouched.
second :: Index b c -> Index (d , b) (d , c)
-- Running @t@ on 'id' recovers its underlying @b -> c@ function.
second t = arr (B.second (runTrans t id))
-- | Apply an 'Index' to 'Left' indices, passing 'Right' indices through.
left :: Index b c -> Index (b + d) (c + d)
-- Running @t@ on 'id' recovers its underlying @b -> c@ function.
left t = arr (B.first (runTrans t id))
-- | Apply an 'Index' to 'Right' indices, passing 'Left' indices through.
right :: Index b c -> Index (d + b) (d + c)
-- Running @t@ on 'id' recovers its underlying @b -> c@ function.
right t = arr (B.second (runTrans t id))
infixr 3 ***
-- | Run two index maps side by side, one on each component of a pair.
(***) :: Index a1 b1 -> Index a2 b2 -> Index (a1 , a2) (b1 , b2)
x *** y = first x >>> second y
{-# INLINE (***) #-}
infixr 2 +++
-- | Run @x@ on 'Left' indices and @y@ on 'Right' indices.
(+++) :: Index a1 b1 -> Index a2 b2 -> Index (a1 + a2) (b1 + b2)
x +++ y = left x >>> right y
{-# INLINE (+++) #-}
infixr 3 &&&
-- | Fan out: feed the same index to both maps and pair their results.
(&&&) :: Index a b1 -> Index a b2 -> Index a (b1 , b2)
x &&& y = lmap fork (x *** y)
{-# INLINE (&&&) #-}
infixr 2 |||
-- | Fan in: send each summand to its own map, then merge the two
-- same-typed results.
(|||) :: Index a1 b -> Index a2 b -> Index (a1 + a2) b
x ||| y = rmap join (x +++ y)
{-# INLINE (|||) #-}
infixr 0 $$$
-- | Applicative-style application.
--
-- Both maps run on the same index, and the first result (a function) is
-- applied to the second.
($$$) :: Index a (b -> c) -> Index a b -> Index a c
f $$$ x = lmap fork (rmap apply (f *** x))
{-# INLINE ($$$) #-}
-- | Split the index with @f@ and route the two halves through @x@ and @y@.
--
-- Only @x@'s output reaches the result, because the pair of outputs is
-- projected with 'fst'; @y@'s output is dropped.
adivide :: (a -> (a1 , a2)) -> Index a1 b -> Index a2 b -> Index a b
adivide f x y = lmap f (rmap fst (x *** y))
{-# INLINE adivide #-}
-- | 'adivide' with the index duplicated to both maps.
adivide' :: Index a b -> Index a b -> Index a b
adivide' = adivide (\v -> (v, v))
{-# INLINE adivide' #-}
-- | 'adivide' on an index that is already a pair (no splitting function needed).
adivided :: Index a1 b -> Index a2 b -> Index (a1 , a2) b
adivided x y = adivide id x y
{-# INLINE adivided #-}
-- | Combine @x@ and @y@ as a sum and merge their outputs with @f@.
--
-- The incoming index is always tagged 'Left', so only @x@ is exercised and
-- @f@ only ever receives a 'Left' value.
aselect :: ((b1 + b2) -> b) -> Index a b1 -> Index a b2 -> Index a b
aselect f x y = lmap Left (rmap f (x +++ y))
{-# INLINE aselect #-}
-- | 'aselect' for two maps of the same result type, merging with the
-- codiagonal.
aselect' :: Index a b -> Index a b -> Index a b
aselect' = aselect (either id id)
{-# INLINE aselect' #-}
-- | 'aselect' that keeps the tagged sum as the result instead of merging it.
aselected :: Index a b1 -> Index a b2 -> Index a (b1 + b2)
aselected x y = aselect id x y
{-# INLINE aselected #-}