{-# LANGUAGE KindSignatures, PolyKinds, MultiParamTypeClasses, FunctionalDependencies, ConstraintKinds, NoImplicitPrelude, TypeFamilies, TypeOperators, FlexibleContexts, FlexibleInstances, UndecidableInstances, RankNTypes, GADTs, ScopedTypeVariables, DataKinds, AllowAmbiguousTypes, LambdaCase, DefaultSignatures, EmptyCase #-}
module Hask.Tensor.Compose where

import Data.Constraint.Unsafe (unsafeCoerceConstraint)
import GHC.Prim (Any)
import Hask.Category
import Hask.Iso
import Hask.Tensor
import Unsafe.Coerce (unsafeCoerce)

--------------------------------------------------------------------------------
-- * Compose
--------------------------------------------------------------------------------

-- | Data type whose promoted constructor @'Compose@ serves as the tag from
-- which the 'Compose' synonym below is built.
data COMPOSE = Compose
-- | Composition of functors, encoded as 'Any' applied to the promoted tag
-- @'Compose@. Per its kind signature it takes three arrow types (on kinds
-- @i@, @j@ and @k@), a @j -> k@ and an @i -> j@ type constructor, and an
-- argument of kind @i@, yielding something of kind @k@.
type Compose = (Any 'Compose :: (i -> i -> *) -> (j -> j -> *) -> (k -> k -> *) -> (j -> k) -> (i -> j) -> i -> k)

-- | Categories @e@ for which a 'Compose' can be converted to and from the
-- nested application it stands for.
class Category e => Composed (e :: k -> k -> *) where
  -- | Isomorphism between @'Compose' c d e f g a@ and @f (g a)@; the primed
  -- variables allow the types to change across the iso. The 'FunctorOf'
  -- constraints tie @f@, @f'@ to @d@ and @e@, and @g@, @g'@ to @c@ and @d@.
  _Compose :: (FunctorOf d e f, FunctorOf d e f', FunctorOf c d g, FunctorOf c d g') => Iso
    e e (->)
    (Compose c d e f g a) (Compose c d e f' g' a')
    (f (g a))             (f' (g' a'))

-- | 'Composed' instance for the function arrow.
instance Composed (->) where
  -- NOTE(review): the iso is asserted via 'unsafeCoerce', not constructed.
  -- Presumably safe because 'Compose' is only an 'Any'-based type synonym
  -- standing in for @f (g a)@ -- confirm before relying on it elsewhere.
  _Compose = unsafeCoerce

-- | 'Composed' instance for the @(:-)@ arrow type.
instance Composed (:-) where
  -- NOTE(review): the iso is asserted via 'unsafeCoerce', not constructed;
  -- nothing in this module checks that the two sides really coincide for
  -- @(:-)@ -- confirm.
  _Compose = unsafeCoerce

-- | 'Composed' instance for @'Nat' c d@, available whenever @d@ is itself
-- 'Composed'.
instance (Category c, Composed d) => Composed (Nat c d) where
  -- NOTE(review): implemented by a bare 'unsafeCoerce'; the method body does
  -- not go through @d@'s own '_Compose'.
  _Compose = unsafeCoerce -- really evil, like super-villain evil

-- | Relates the constraint @'Compose' c d e f g a@ to @f (g a)@ through
-- 'Class'. The entailment 'cls' is produced with 'unsafeCoerceConstraint'
-- rather than derived.
instance (Category c, Category d, Category e) => Class (f (g a)) (Compose c d e f g a) where cls = unsafeCoerceConstraint
-- | Relates @f (g a)@ to the constraint @'Compose' c d e f g a@ through
-- '(:=>)'. The entailment 'ins' is produced with 'unsafeCoerceConstraint'
-- rather than derived.
instance (Category c, Category d, Category e) => f (g a) :=> Compose c d e f g a where ins = unsafeCoerceConstraint

-- | 'Compose' is functorial in its first functor argument: an arrow of
-- @'Nat' d e@ from @a@ to @b@ becomes an arrow from @'Compose' c d e a@ to
-- @'Compose' c d e b@ in @'Nat' ('Nat' c d) ('Nat' c e)@.
instance
  ( Category c
  , Category d
  , Composed e
  ) => Functor (Compose c d e) where
  type Dom (Compose c d e) = Nat d e
  type Cod (Compose c d e) = Nat (Nat c d) (Nat c e)
  fmap = go where
    -- The local signature mentions the instance head's @c@, @d@ and @e@.
    go :: Nat d e a b -> Nat (Nat c d) (Nat c e) (Compose c d e a) (Compose c d e b)
    -- The 'Nat' constructor is matched before the result is built; each
    -- component is the transformation indexed by @h@ and then @x@, passed
    -- through '_Compose'.
    go t@Nat{} = nat (\h -> nat (\x -> _Compose (t ! h ! x)))

-- | With the outer functor @f@ fixed, 'Compose' is a functor from
-- @'Nat' c d@ to @'Nat' c e@.
instance
  ( Category c
  , Category d
  , Composed e
  , Functor f
  , e ~ Cod f
  , d ~ Dom f
  ) => Functor (Compose c d e f) where
  type Dom (Compose c d e f) = Nat c d
  type Cod (Compose c d e f) = Nat c e
  -- Map the component @t@ with @f@'s own 'fmap', carry it across
  -- '_Compose', and repackage the result as a 'Nat'.
  fmap (Nat t) = Nat (_Compose (fmap t))

-- | With both functors fixed, @'Compose' c d e f g@ is a functor from @c@
-- to @e@.
instance
  ( Category c
  , Category d
  , Composed e
  , Functor f
  , Functor g
  , e ~ Cod f
  , d ~ Cod g
  , d ~ Dom f
  , c ~ Dom g
  ) => Functor (Compose c d e f g) where
  type Dom (Compose c d e f g) = c
  type Cod (Compose c d e f g) = e
  -- Apply 'fmap' twice (the inner call is @g@'s, the outer one @f@'s) and
  -- carry the result across '_Compose'.
  fmap h = _Compose (fmap (fmap h))

-- | 'Compose' restricted to endofunctors (kind @i -> i@) of a single
-- category -- the three category arguments are forced equal by the context --
-- is a 'Semitensor'.
instance (Composed c, c ~ c', c' ~ c'') => Semitensor (Compose c c' c'' :: (i -> i) -> (i -> i) -> (i -> i)) where
  -- The associator is the top-level 'associateCompose'.
  associate = associateCompose

-- | Data type whose promoted constructor @'Id@ tags the 'Id' synonym below.
data ID = Id
-- | Identity functor on a category, encoded as 'Any' applied to the promoted
-- tag @'Id@. Per its kind signature it takes an arrow type on kind @i@ and an
-- argument of kind @i@, yielding kind @i@.
type Id = (Any 'Id :: (i -> i -> *) -> i -> i)

-- | Categories @c@ for which @'Id' c a@ can be converted to and from @a@.
class Category c => Identified (c :: i -> i -> *) where
  -- | Isomorphism between @'Id' c a@ and @a@; the primed variable allows the
  -- type to change across the iso.
  _Id :: Iso c c (->) (Id c a) (Id c a') a a'

-- | 'Identified' instance for the function arrow.
instance Identified (->) where
  -- NOTE(review): the iso is asserted via 'unsafeCoerce', not constructed.
  -- Presumably safe because 'Id' is only an 'Any'-based type synonym standing
  -- in for its argument -- confirm.
  _Id = unsafeCoerce

-- | 'Identified' instance for the @(:-)@ arrow type.
instance Identified (:-) where
  -- NOTE(review): the iso is asserted via 'unsafeCoerce', not constructed;
  -- nothing in this module checks it for @(:-)@ -- confirm.
  _Id = unsafeCoerce

-- | 'Identified' instance for @'Nat' c d@, available whenever @d@ is itself
-- 'Identified'.
instance (Category c, Identified d) => Identified (Nat c d) where
  -- NOTE(review): implemented by a bare 'unsafeCoerce'; the method body does
  -- not go through @d@'s own '_Id'.
  _Id = unsafeCoerce

-- | Relates the constraint @'Id' c a@ to @a@ through 'Class'. The entailment
-- 'cls' is produced with 'unsafeCoerceConstraint' rather than derived.
instance Category c => Class a (Id c a) where cls = unsafeCoerceConstraint
-- | Relates @a@ to the constraint @'Id' c a@ through '(:=>)'. The entailment
-- 'ins' is produced with 'unsafeCoerceConstraint' rather than derived.
instance Category c => a :=> Id c a where ins = unsafeCoerceConstraint

-- | @'Id' c@ is a functor from @c@ to @c@.
instance Identified c => Functor (Id c) where
  type Dom (Id c) = c
  type Cod (Id c) = c
  -- Mapping an arrow simply transports it across the '_Id' isomorphism.
  fmap = _Id

-- | Associates @'Id' c@ with @'Compose' c c c@ through the 'I' type family
-- (presumably the unit-object family of the tensor machinery imported from
-- "Hask.Tensor" -- confirm).
type instance I (Compose c c c) = Id c

-- | @'Id' c@ as a 'Semigroup' object with respect to @'Compose' c c c@.
instance (Identified c, Composed c) => Semigroup (Compose c c c) (Id c) where
  -- The multiplication is 'id' with @get lambda@ supplied as the first
  -- argument of 'dimap' (and 'id' as the second), i.e. it is derived from
  -- the tensor's 'lambda' isomorphism.
  mu = dimap (get lambda) id id

-- | Unit for @'Id' c@ as a monoid object with respect to @'Compose' c c c@.
instance (Identified c, Composed c) => Monoid' (Compose c c c) (Id c) where
  -- The argument is ignored; every component is the identity arrow carried
  -- across '_Id'.
  eta _ = Nat (_Id id)

-- | @'Id' c@ as a 'Cosemigroup' object with respect to @'Compose' c c c@.
instance (Identified c, Composed c) => Cosemigroup (Compose c c c) (Id c) where
  -- The comultiplication is 'id' with @beget lambda@ supplied as the second
  -- argument of 'dimap' (and 'id' as the first), i.e. it is derived from the
  -- tensor's 'lambda' isomorphism.
  delta = dimap id (beget lambda) id

-- | Counit for @'Id' c@ as a comonoid object with respect to
-- @'Compose' c c c@.
instance (Identified c, Composed c) => Comonoid' (Compose c c c) (Id c) where
  -- The argument is ignored; every component is the identity arrow carried
  -- across '_Id'.
  epsilon _ = Nat (_Id id)

-- | 'Compose' on endofunctors (kind @i -> i@) of a single category is a
-- @Tensor'@; both unit isomorphisms are the top-level definitions below.
instance (Identified c, Composed c) => Tensor' (Compose c c c :: (i -> i) -> (i -> i) -> (i -> i)) where
  -- Iso between @Compose c c c (Id c) a@ and @a@.
  lambda = lambdaCompose
  -- Iso between @Compose c c c a (Id c)@ and @a@.
  rho = rhoCompose

-- | Associator for 'Compose': an isomorphism in @'Nat' b e@ between
-- @'Compose' b c e ('Compose' c d e f g) h@ and
-- @'Compose' b d e f ('Compose' b c d g h)@ (and between the primed
-- counterparts on the other side of the iso).
--
-- Both directions are natural transformations built with 'nat'. At each
-- object they first pattern-match three 'Dict's obtained from 'obOf', bringing
-- the corresponding 'Ob' instances into scope, and then compose '_Compose'
-- conversions.
associateCompose :: forall b c d e f g h f' g' h'.
   ( Category b, Category c, Composed d, Composed e
   , FunctorOf d e f, FunctorOf c d g, FunctorOf b c h
   , FunctorOf d e f', FunctorOf c d g', FunctorOf b c h'
   ) => Iso (Nat b e) (Nat b e) (->)
  (Compose b c e (Compose c d e f g) h) (Compose b c e (Compose c d e f' g') h')
  (Compose b d e f (Compose b c d g h)) (Compose b d e f' (Compose b c d g' h'))
associateCompose = dimap hither yon where
  -- Direction over the unprimed functors. The three dictionaries are for
  -- @f (g (h a))@, @f (Compose b c d g h a)@ and @Compose c d e f g (h a)@.
  -- The arrow itself is, right to left: two @get _Compose@ steps, an
  -- @fmap (beget _Compose)@, then a final @beget _Compose@.
  hither = nat $ \a -> case obOf (id :: NatId f) $ (id :: NatId g) ! (id :: NatId h) ! a of
    Dict -> case obOf (id :: NatId f) $ (id :: NatId (Compose b c d g h)) ! a of
      Dict -> case obOf (id :: NatId (Compose c d e f g)) $ (id :: NatId h) ! a of
        Dict -> beget _Compose . fmap (beget _Compose) . get _Compose . get _Compose
  -- Direction over the primed functors, with the analogous three
  -- dictionaries. The arrow is, right to left: @get _Compose@,
  -- @fmap (get _Compose)@, then two @beget _Compose@ steps.
  yon = nat $ \a -> case obOf (id :: NatId f') $ (id :: NatId g') ! (id :: NatId h') ! a of
    Dict -> case obOf (id :: NatId f') $ (id :: NatId (Compose b c d g' h')) ! a of
      Dict -> case obOf (id :: NatId (Compose c d e f' g')) $ (id :: NatId h') ! a of
        Dict -> beget _Compose . beget _Compose . fmap (get _Compose) . get _Compose

-- | Isomorphism in @'Nat' c c@ between @'Compose' c c c ('Id' c) a@ and @a@;
-- it is the 'lambda' of the @Tensor'@ instance for 'Compose' in this module.
lambdaCompose
  :: forall a a' c.
     (Identified c, Composed c, Ob (Nat c c) a, Ob (Nat c c) a')
  => Iso (Nat c c) (Nat c c) (->)
       (Compose c c c (Id c) a) (Compose c c c (Id c) a')
       a a'
lambdaCompose = dimap fwd bwd where
  -- At each object, bring the 'Ob' instance for @Id c (a x)@ into scope,
  -- then apply @get _Compose@ followed by @get _Id@.
  fwd = nat (\x ->
    case obOf (id :: NatId (Id c)) ((id :: NatId a) ! x) of
      Dict -> get _Id . get _Compose)
  -- Same dictionary for @a'@, then @beget _Id@ followed by
  -- @beget _Compose@.
  bwd = nat (\x ->
    case obOf (id :: NatId (Id c)) ((id :: NatId a') ! x) of
      Dict -> beget _Compose . beget _Id)

-- | Isomorphism in @'Nat' c c@ between @'Compose' c c c a ('Id' c)@ and @a@;
-- it is the 'rho' of the @Tensor'@ instance for 'Compose' in this module.
rhoCompose
  :: forall a a' c.
     (Identified c, Composed c, Ob (Nat c c) a, Ob (Nat c c) a')
  => Iso (Nat c c) (Nat c c) (->)
       (Compose c c c a (Id c)) (Compose c c c a' (Id c))
       a a'
rhoCompose = dimap fwd bwd where
  -- At each object, bring the 'Ob' instance for @a (Id c x)@ into scope,
  -- then apply @get _Compose@ followed by @fmap (get _Id)@.
  fwd = nat (\x ->
    case obOf (id :: NatId a) ((id :: NatId (Id c)) ! x) of
      Dict -> fmap (get _Id) . get _Compose)
  -- Same dictionary for @a'@, then @fmap (beget _Id)@ followed by
  -- @beget _Compose@.
  bwd = nat (\x ->
    case obOf (id :: NatId a') ((id :: NatId (Id c)) ! x) of
      Dict -> beget _Compose . fmap (beget _Id))

--------------------------------------------------------------------------------
-- ** Monads
--------------------------------------------------------------------------------

-- | A monad is an endofunctor (@Dom m ~ Cod m@) that is a 'Monoid' with
-- respect to 'Compose' over its own domain category, where that category is
-- both 'Identified' and 'Composed'. The class has no methods; it only bundles
-- these constraints.
class    (Functor m, Dom m ~ Cod m, Monoid (Compose (Dom m) (Dom m) (Dom m)) m, Identified (Dom m), Composed (Dom m)) => Monad m
-- | Catch-all instance: anything satisfying the bundled constraints is a
-- 'Monad'.
instance (Functor m, Dom m ~ Cod m, Monoid (Compose (Dom m) (Dom m) (Dom m)) m, Identified (Dom m), Composed (Dom m)) => Monad m

-- | Unit of the monad @m@ as an arrow @a -> m a@ in @Dom m@.
--
-- It is a component ('runNat') of 'eta' for the 'Compose' tensor on @Dom m@,
-- precomposed with @beget _Id@. The @id :: NatId (Compose ...)@ argument only
-- selects which tensor 'eta' is taken for.
return :: forall m a. (Monad m, Ob (Dom m) a) => Dom m a (m a)
return = runNat (eta (id :: NatId (Compose (Dom m) (Dom m) (Dom m)))) . beget _Id

-- | Extend an arrow @a -> m b@ of @Dom m@ to an arrow @m a -> m b@.
--
-- The result maps with @fmap f@, applies @beget _Compose@, and finishes with
-- a component ('runNat') of the monoid multiplication 'mu'.
bind :: forall m a b. (Monad m, Ob (Dom m) b) => Dom m a (m b) -> Dom m (m a) (m b)
-- 'observe' yields a 'Dict' for the endpoints of @f@; matching on it brings
-- those instances into scope.
bind f = case observe f of
  -- 'obOf' applied to the identity on @m b@ yields one more 'Dict', for @m@
  -- applied to @m b@.
  Dict -> case obOf (id :: NatId m) (id :: Endo (Cod m) (m b)) of
    Dict -> runNat mu . beget _Compose . fmap f