{-# LANGUAGE CPP, DeriveFunctor, DeriveFoldable, DeriveTraversable, LambdaCase, RecordWildCards, ViewPatterns #-}
module TypeLevel.Rewrite.Internal.DecomposedConstraint where

import Control.Applicative

-- GHC API
import GHC (Class, Type)
#if MIN_VERSION_ghc(9,0,0)
import GHC.Tc.Types.Constraint (Ct, ctEvPred, ctEvidence)
import GHC.Core.Predicate (EqRel(NomEq), Pred(ClassPred, EqPred), classifyPredType, mkClassPred, mkPrimEqPred)
#else
import Constraint (Ct, ctEvPred, ctEvidence)
import Predicate (EqRel(NomEq), Pred(ClassPred, EqPred), classifyPredType, mkClassPred, mkPrimEqPred)
#endif


-- | A constraint split into its head and its arguments, parameterised over
-- the representation @a@ of those arguments.
--
-- The derived 'Functor', 'Foldable' and 'Traversable' instances act on the
-- @a@ positions only: both sides of an equality, or every argument of a class
-- constraint. The 'Class' itself is left untouched.
data DecomposedConstraint a
  = EqualityConstraint a a        -- lhs ~ rhs
  | InstanceConstraint Class [a]  -- C a b c
  deriving (Functor, Foldable, Traversable)

-- | View a constraint as an equality between two types.
--
-- Returns @Just (lhs, rhs)@ when the constraint's predicate is classified by
-- 'classifyPredType' as an 'EqPred' whose relation is 'NomEq'. Any other
-- predicate, including an 'EqPred' with a different 'EqRel', yields 'Nothing'.
asEqualityConstraint
  :: Ct
  -> Maybe (Type, Type)
asEqualityConstraint ct
  = case classifyPredType (ctEvPred (ctEvidence ct)) of
      EqPred NomEq lhs rhs
        -> Just (lhs, rhs)
      _ -> Nothing

-- | View a constraint as a typeclass constraint.
--
-- Returns @Just (cls, args)@ when the constraint's predicate is classified by
-- 'classifyPredType' as a 'ClassPred', where @cls@ is the class and @args@
-- are the types it is applied to. Any other predicate yields 'Nothing'.
asInstanceConstraint
  :: Ct
  -> Maybe (Class, [Type])
asInstanceConstraint ct
  = case classifyPredType (ctEvPred (ctEvidence ct)) of
      ClassPred typeclass args
        -> Just (typeclass, args)
      _ -> Nothing

-- | Decompose a constraint into a 'DecomposedConstraint', if it has one of
-- the two supported shapes.
--
-- The equality view is tried first and the class view second, combined with
-- '<|>'. The result is 'Nothing' when neither 'asEqualityConstraint' nor
-- 'asInstanceConstraint' recognises the constraint.
asDecomposedConstraint
  :: Ct
  -> Maybe (DecomposedConstraint Type)
asDecomposedConstraint ct
    = (uncurry EqualityConstraint <$> asEqualityConstraint ct)
  <|> (uncurry InstanceConstraint <$> asInstanceConstraint ct)

-- | Rebuild a predicate 'Type' from its decomposed form.
--
-- An 'EqualityConstraint' is rebuilt with 'mkPrimEqPred' from its two sides.
-- An 'InstanceConstraint' is rebuilt with 'mkClassPred' from the class and
-- its argument list.
fromDecomposeConstraint
  :: DecomposedConstraint Type
  -> Type
fromDecomposeConstraint = \case
  EqualityConstraint t1 t2
    -> mkPrimEqPred t1 t2
  InstanceConstraint cls args
    -> mkClassPred cls args