{-# LANGUAGE ImpredicativeTypes #-}
{-# LANGUAGE UndecidableInstances #-}

-- This Source Code Form is subject to the terms of the Mozilla Public
-- License, v. 2.0. If a copy of the MPL was not distributed with this
-- file, You can obtain one at https://mozilla.org/MPL/2.0/.

{- |
Copyright   :  (c) 2023 Yamada Ryo
License     :  MPL-2.0 (see the file LICENSE)
Maintainer  :  ymdfield@outlook.jp
Stability   :  experimental
Portability :  portable

An implementation of an open union for higher-order effects using
the [extensible](https://hackage.haskell.org/package/extensible) package as a backend.
-}
module Data.Hefty.Extensible (
    module Data.Hefty.Extensible,
    Forall,
) where

import Control.Effect.Free qualified as E
import Control.Effect.Hefty qualified as E
import Data.Effect (SigClass)
import Data.Effect.HFunctor (HFunctor, hfmap)
import Data.Extensible (Forall, Match (Match), htabulateFor, match)
import Data.Extensible.Sum (strikeAt, (<:|), type (:/) (EmbedAt))
import Data.Extensible.Sum qualified as E
import Data.Hefty.Union (
    ClassIndex,
    HFunctorUnion_ (ForallHFunctor),
    Union (
        HasMembership,
        exhaust,
        inject,
        inject0,
        project,
        weaken,
        (|+:)
    ),
 )
import Data.Hefty.Union qualified as U
import Data.Hefty.Union qualified as Union
import Data.Proxy (Proxy (Proxy))
import Data.Type.Equality ((:~:) (Refl))
import GHC.TypeNats (KnownNat)
import Type.Membership.Internal (
    Elaborate,
    Elaborated (Expecting),
    FindType,
    Membership,
    leadership,
    membership,
    nextMembership,
 )
import Unsafe.Coerce (unsafeCoerce)

{- |
An implementation of an open union for higher-order effects using
the [extensible](https://hackage.haskell.org/package/extensible) package as a backend.

A value of type @ExtensibleUnion es f a@ is a thin wrapper around extensible's
open sum @es ':/' 'FieldApp' f a@. It holds one signature value @e f a@ for some
@e@ in @es@, paired with a 'Membership' witness for @e@ in @es@.
-}
newtype ExtensibleUnion es f a = ExtensibleUnion {unExtensibleUnion :: es :/ FieldApp f a}

{- |
Adapter that puts the signature @e@ in the last type-parameter position.

extensible's sums index their payload functor @h@ by the list element
(@xs ':/' h@ holds some @h x@). 'FieldApp' fixes the carrier @f@ and the
result type @a@, so @'FieldApp' f a e@ is simply @e f a@.
-}
newtype FieldApp f a (e :: SigClass) = FieldApp {unFieldApp :: e f a}

{- |
Maps a natural transformation over whichever signature the union holds.

extensible's 'match' needs a table with one handler per member of @es@.
'htabulateFor' builds that table, using each member's 'HFunctor' instance
(hence the @'Forall' 'HFunctor' es@ constraint). Each handler re-embeds the
mapped signature under the same membership witness @w@ it was found at.
-}
instance Forall HFunctor es => HFunctor (ExtensibleUnion es) where
    hfmap f =
        ExtensibleUnion
            . match
                ( htabulateFor @HFunctor Proxy (\w ->
                    Match (EmbedAt w . FieldApp . hfmap f . unFieldApp))
                )
            . unExtensibleUnion
    {-# INLINE hfmap #-}

-- TODO: add Functor, Foldable, and Traversable instances for 'ExtensibleUnion'.

{- |
'Union' implemented on top of extensible's open sum.

Membership of @e@ in @es@ is witnessed by the type-level index
@'ClassIndex' es e@, which must be a 'KnownNat'. The index is turned into a
runtime 'Membership' by 'findFirstMembership'.
-}
instance Union ExtensibleUnion where
    type HasMembership _ e es = KnownNat (ClassIndex es e)

    -- Embed at the position computed by 'findFirstMembership' (presumably the
    -- first occurrence of @e@ in @es@, per its name).
    inject = ExtensibleUnion . EmbedAt findFirstMembership . FieldApp
    {-# INLINE inject #-}

    -- 'strikeAt' yields 'Nothing' unless the held value sits at exactly that
    -- membership position.
    project (ExtensibleUnion u) = unFieldApp <$> strikeAt findFirstMembership u
    {-# INLINE project #-}

    exhaust = E.exhaust . unExtensibleUnion
    {-# INLINE exhaust #-}

    -- Embed at the head of the list.
    inject0 = ExtensibleUnion . EmbedAt leadership . FieldApp
    {-# INLINE inject0 #-}

    -- Shift the membership witness one position right; the payload is untouched.
    weaken (ExtensibleUnion (EmbedAt w e)) =
        ExtensibleUnion $ EmbedAt (nextMembership w) e
    {-# INLINE weaken #-}

    -- Case split: the head member goes to @f@, anything in the tail is
    -- rewrapped as a smaller union and passed to @g@.
    f |+: g = ((f . unFieldApp) <:| (g . ExtensibleUnion)) . unExtensibleUnion
    {-# INLINE (|+:) #-}

{- |
Builds the 'Membership' witness for @x@ in @xs@ from the type-level index
@'ClassIndex' xs x@.

The membership package's 'membership' needs a @Member xs x@ instance. This
code relies on that instance being solvable from
@Elaborate x (FindType x xs) ~ 'Expecting pos@ together with @'KnownNat' pos@.
We already have a 'KnownNat' index, so the type equality is forged with
'unsafeCoerce'. This is unchecked: nothing verifies that @pos@ agrees with the
package's own search.

This depends on @Type.Membership.Internal@ and may break whenever the
membership package changes its internals.
-}
findFirstMembership :: forall xs x. KnownNat (ClassIndex xs x) => Membership xs x
findFirstMembership = unsafeMkMembership @(ClassIndex xs x) Proxy
  where
    -- This hack may break if the membership package version gets updated.
    unsafeMkMembership :: forall pos. Proxy pos -> KnownNat pos => Membership xs x
    unsafeMkMembership _ = case hackedEquality of Refl -> membership
      where
        -- Forged proof that the package's type-level lookup lands on @pos@.
        hackedEquality :: Elaborate x (FindType x xs) :~: 'Expecting pos
        hackedEquality = unsafeCoerce Refl

-- | For 'ExtensibleUnion', the "all members are higher-order functors"
-- constraint is extensible's @'Forall' 'HFunctor'@.
instance HFunctorUnion_ (Forall HFunctor) ExtensibleUnion where
    type ForallHFunctor ExtensibleUnion = Forall HFunctor

-- | 'U.Member' specialized to 'ExtensibleUnion': @e@ is a member of @es@.
type e <| es = U.Member ExtensibleUnion e es
-- | 'U.MemberH' specialized to 'ExtensibleUnion' (presumably the
-- higher-order-signature counterpart of '<|').
type e <<| es = U.MemberH ExtensibleUnion e es

-- | 'U.MemberBy' specialized to 'ExtensibleUnion'.
type MemberBy key e efs = U.MemberBy ExtensibleUnion key e efs
-- | 'U.MemberHBy' specialized to 'ExtensibleUnion'.
type MemberHBy key e ehs = U.MemberHBy ExtensibleUnion key e ehs

infix 3 <|
infix 3 <<|

-- | Constraint requiring every signature in a type-level list to be an
-- 'HFunctor' (extensible's 'Forall' applied to 'HFunctor').
type ForallHFunctor = Forall HFunctor

-- | 'Union.U' specialized to 'ExtensibleUnion'.
type U ef = Union.U ExtensibleUnion ef
-- | 'Union.UH' specialized to 'ExtensibleUnion'.
type UH eh = Union.UH ExtensibleUnion eh

-- | 'Union.S' specialized to 'ExtensibleUnion'.
type S ef = Union.S ExtensibleUnion ef
-- | 'Union.SH' specialized to 'ExtensibleUnion'.
type SH eh = Union.SH ExtensibleUnion eh

-- | The effect monad from "Control.Effect.Hefty" using 'ExtensibleUnion'.
type Eff fr eh ef = E.Eff ExtensibleUnion fr eh ef
-- | 'E.EffF' using 'ExtensibleUnion' (which of the @E@-qualified modules
-- defines it is not visible here).
type EffF fr es = E.EffF ExtensibleUnion fr es