-- SPDX-FileCopyrightText: 2022 Oxhead Alpha
-- SPDX-License-Identifier: LicenseRef-MIT-OA

{-# LANGUAGE QuantifiedConstraints, DeriveLift #-}
{-# OPTIONS_GHC -Wno-orphans #-}

module Morley.Util.Constrained
  ( Constrained(..)
  , NullConstraint
  , mapConstrained
  , traverseConstrained
  , withConstrained
  , withConstrainedM
  , foldConstrained
  , foldConstrainedM
  ) where

import Data.GADT.Compare (GCompare, GEq, defaultCompare, defaultEq)
import Fmt (Buildable(..))
import Language.Haskell.TH.Syntax (Lift)

-- | A unary constraint that holds for every type, at any kind.
--
-- Plugging it into 'Constrained' (as @Constrained NullConstraint f@) yields
-- an existential wrapper that demands nothing of the hidden index. That is
-- effectively a somewhat inefficient @Some f@.
type NullConstraint :: forall k. k -> Constraint
class NullConstraint a
instance NullConstraint a

-- | A value of type @f a@ packed together with evidence that @c a@ holds,
-- with the index @a@ hidden (existentially quantified). Pattern matching on
-- the 'Constrained' constructor brings both the value and the @c a@
-- instance back into scope.
type Constrained :: (k -> Constraint) -> (k -> Type) -> Type
data Constrained c f where
  -- | Wrap a value whose (soon to be hidden) index satisfies @c@.
  Constrained :: forall c f a. c a => f a -> Constrained c f

-- | Map over argument.
--
-- Applies an index-preserving function to the wrapped value. The function
-- may use the @c t@ instance recovered from the wrapper. Because it returns
-- @g t@ for the same @t@, that same evidence is packed into the result.
mapConstrained :: (forall t. c t => f t -> g t) -> Constrained c f -> Constrained c g
mapConstrained f (Constrained x) = Constrained (f x)

-- | Traverse over argument.
--
-- Runs an effectful, index-preserving function on the wrapped value and
-- rewraps the result inside the functor @m@. The function may use the
-- @c a@ instance recovered from the wrapper. Only 'Functor' is required,
-- since there is exactly one value to map over.
traverseConstrained
  :: Functor m
  => (forall a. c a => f a -> m (g a))
  -> Constrained c f
  -> m (Constrained c g)
traverseConstrained f (Constrained x) = Constrained <$> f x

-- | Apply function to constrained value.
--
-- Unpacks the wrapper and passes the hidden value to a function that must
-- work for any index @t@. The function may use the @c t@ instance that was
-- stored in the wrapper. The result type @r@ cannot mention @t@, so the
-- hidden index never escapes.
withConstrained :: Constrained c f -> (forall t. c t => f t -> r) -> r
withConstrained (Constrained x) f = f x

-- | Monadic 'withConstrained'.
--
-- Runs the action producing the wrapper, then feeds the hidden value to the
-- continuation.
--
-- The continuation receives the @c t@ instance stored in the wrapper, just
-- like the callbacks of 'withConstrained' and 'foldConstrainedM'. The
-- previous signature, @forall t. f t -> m r@, threw that evidence away.
-- Any function of that older shape is still accepted here.
withConstrainedM :: Monad m => m (Constrained c f) -> (forall t. c t => f t -> m r) -> m r
withConstrainedM m f = m >>= foldConstrained f

-- | Flipped version of 'withConstrained'.
--
-- Eliminates a 'Constrained' with a function that works for any index @t@
-- satisfying @c@. Putting the function first makes it convenient for
-- partial application, e.g. @m >>= foldConstrained f@.
foldConstrained :: (forall t. c t => f t -> r) -> Constrained c f -> r
foldConstrained f (Constrained x) = f x

-- | Flipped version of 'withConstrainedM'.
--
-- Runs the action producing the wrapper, then eliminates the result with a
-- continuation that may use the stored @c t@ instance.
foldConstrainedM :: Monad m => (forall t. c t => f t -> m r) -> m (Constrained c f) -> m r
foldConstrainedM f m = m >>= foldConstrained f

-- Stock-derived 'Show'. The index is hidden, so the quantified constraint
-- requires a 'Show' instance for @f a@ at every index @a@ satisfying @c@.
deriving stock instance (forall a. c a => Show (f a)) => Show (Constrained c f)

-- Forces the wrapped value using its own 'NFData' instance. The quantified
-- constraint guarantees that instance exists at whatever index @c@ admits.
instance (forall a. c a => NFData (f a)) => NFData (Constrained c f) where
  rnf (Constrained x) = rnf x

-- Stock-derived Template Haskell 'Lift'. As with 'Show', the quantified
-- constraint requires 'Lift' for @f a@ at every index @a@ satisfying @c@.
deriving stock instance (forall a. c a => Lift (f a)) => Lift (Constrained c f)

-- Renders the wrapped value with its own 'Buildable' instance. The wrapper
-- itself contributes no extra text.
instance (forall a. c a => Buildable (f a)) => Buildable (Constrained c f) where
  build (Constrained x) = build x

-- Equality delegates to 'defaultEq' from "Data.GADT.Compare". Its type
-- @f a -> f b -> Bool@ lets it compare two wrapped values whose hidden
-- indices may differ. The constraint @c@ plays no part in the comparison.
instance GEq f => Eq (Constrained c f) where
  Constrained a == Constrained b = defaultEq a b

-- Ordering delegates to 'defaultCompare' from "Data.GADT.Compare". Like
-- 'defaultEq', it accepts values whose hidden indices may differ, and the
-- constraint @c@ is ignored. The 'Eq' superclass is satisfied because
-- 'GCompare' implies 'GEq'.
instance GCompare f => Ord (Constrained c f) where
  compare (Constrained a) (Constrained b) = defaultCompare a b