module Data.Category.Monoidal where
import Prelude (($), uncurry)
import qualified Control.Monad as M
import qualified Data.Monoid as M
import Data.Category
import Data.Category.Functor
import Data.Category.NaturalTransformation
import Data.Category.Product
import Data.Category.Limit
class Functor f => HasUnit f where
  
  type Unit f :: *
  unitObject :: f -> Obj (Cod f) (Unit f)
instance (HasTerminalObject (~>), HasBinaryProducts (~>)) => HasUnit (ProductFunctor (~>)) where
  
  type Unit (ProductFunctor (~>)) = TerminalObject (~>)
  unitObject _ = terminalObject
instance (HasInitialObject (~>), HasBinaryCoproducts (~>)) => HasUnit (CoproductFunctor (~>)) where
  
  type Unit (CoproductFunctor (~>)) = InitialObject (~>)
  unitObject _ = initialObject
instance Category (~>) => HasUnit (FunctorCompose (~>)) where
  
  type Unit (FunctorCompose (~>)) = Id (~>)
  unitObject _ = natId Id
  
class HasUnit f => TensorProduct f where
  
  leftUnitor     :: Cod f ~ (~>) => f -> Obj (Cod f) a -> (f :% (Unit f, a)) ~> a
  leftUnitorInv  :: Cod f ~ (~>) => f -> Obj (Cod f) a -> a ~> (f :% (Unit f, a))
  rightUnitor    :: Cod f ~ (~>) => f -> Obj (Cod f) a -> (f :% (a, Unit f)) ~> a
  rightUnitorInv :: Cod f ~ (~>) => f -> Obj (Cod f) a -> a ~> (f :% (a, Unit f))
  
  associator     :: Cod f ~ (~>) => f -> Obj (Cod f) a -> Obj (Cod f) b -> Obj (Cod f) c -> (f :% (f :% (a, b), c)) ~> (f :% (a, f :% (b, c)))
  associatorInv  :: Cod f ~ (~>) => f -> Obj (Cod f) a -> Obj (Cod f) b -> Obj (Cod f) c -> (f :% (a, f :% (b, c))) ~> (f :% (f :% (a, b), c))
instance (HasTerminalObject (~>), HasBinaryProducts (~>)) => TensorProduct (ProductFunctor (~>)) where
  
  leftUnitor     _ a = proj2 terminalObject a
  leftUnitorInv  _ a = terminate a &&& a
  rightUnitor    _ a = proj1 a terminalObject
  rightUnitorInv _ a = a &&& terminate a
  associator    _ a b c = (proj1 a b . proj1 (a *** b) c) &&& (proj2 a b *** c)
  associatorInv _ a b c = (a *** proj1 b c) &&& (proj2 b c . proj2 a (b *** c))
instance (HasInitialObject (~>), HasBinaryCoproducts (~>)) => TensorProduct (CoproductFunctor (~>)) where
  
  leftUnitor     _ a = initialize a ||| a
  leftUnitorInv  _ a = inj2 initialObject a
  rightUnitor    _ a = a ||| initialize a
  rightUnitorInv _ a = inj1 a initialObject
  
  associator    _ a b c = (a +++ inj1 b c) ||| (inj2 a (b +++ c) . inj2 b c)
  associatorInv _ a b c = (inj1 (a +++ b) c . inj1 a b) ||| (inj2 a b +++ c)
  
instance Category (~>) => TensorProduct (FunctorCompose (~>)) where
  
  leftUnitor     _ (Nat g _ _) = Nat (Id :.: g) g $ \i -> g % i
  leftUnitorInv  _ (Nat g _ _) = Nat g (Id :.: g) $ \i -> g % i
  rightUnitor    _ (Nat g _ _) = Nat (g :.: Id) g $ \i -> g % i
  rightUnitorInv _ (Nat g _ _) = Nat g (g :.: Id) $ \i -> g % i
  associator    _ (Nat f _ _) (Nat g _ _) (Nat h _ _) = Nat ((f :.: g) :.: h) (f :.: (g :.: h)) $ \i -> f % g % h % i
  associatorInv _ (Nat f _ _) (Nat g _ _) (Nat h _ _) = Nat (f :.: (g :.: h)) ((f :.: g) :.: h) $ \i -> f % g % h % i
data MonoidObject f a = MonoidObject
  { unit     :: (Cod f ~ (~>)) => Unit f        ~> a
  , multiply :: (Cod f ~ (~>)) => (f :% (a, a)) ~> a
  }
  
data ComonoidObject f a = ComonoidObject
  { counit     :: (Cod f ~ (~>)) => a ~> Unit f
  , comultiply :: (Cod f ~ (~>)) => a ~> (f :% (a, a))
  }
preludeMonoid :: M.Monoid m => MonoidObject (ProductFunctor (->)) m
preludeMonoid = MonoidObject M.mempty (uncurry M.mappend)
data MonoidAsCategory f m a b where
  MonoidValue :: (TensorProduct f , Dom f ~ ((~>) :**: (~>)), Cod f ~ (~>))
              => f -> MonoidObject f m -> Unit f ~> m -> MonoidAsCategory f m m m
instance Category (MonoidAsCategory f m) where
  
  src (MonoidValue f m _) = MonoidValue f m $ unit m
  tgt (MonoidValue f m _) = MonoidValue f m $ unit m
  
  MonoidValue f m a . MonoidValue _ _ b = MonoidValue f m $ multiply m . f % (a :**: b) . leftUnitorInv f (unitObject f)
type Monad f = MonoidObject (FunctorCompose (Dom f)) f
mkMonad :: (Functor f, Dom f ~ (~>), Cod f ~ (~>), Category (~>)) 
  => f 
  -> (forall a. Obj (~>) a -> Component (Id (~>)) f a) 
  -> (forall a. Obj (~>) a -> Component (f :.: f) f a)
  -> Monad f
mkMonad f ret join = MonoidObject
  { unit     = Nat Id        f ret
  , multiply = Nat (f :.: f) f join
  }
preludeMonad :: (M.Functor f, M.Monad f) => Monad (EndoHask f)
preludeMonad = mkMonad EndoHask (\_ -> M.return) (\_ -> M.join)
monadFunctor :: forall f. Monad f -> f
monadFunctor (unit -> Nat _ f _) = f
type Comonad f = ComonoidObject (FunctorCompose (Dom f)) f
mkComonad :: (Functor f, Dom f ~ (~>), Cod f ~ (~>), Category (~>)) 
  => f 
  -> (forall a. Obj (~>) a -> Component f (Id (~>)) a) 
  -> (forall a. Obj (~>) a -> Component f (f :.: f) a)
  -> Comonad f
mkComonad f extr dupl = ComonoidObject
  { counit     = Nat f Id        extr
  , comultiply = Nat f (f :.: f) dupl
  }