{-# LANGUAGE FlexibleInstances    #-}
{-# LANGUAGE UndecidableInstances #-}

{-# OPTIONS_GHC -Wno-orphans #-}

module Language.Haskell.Liquid.Transforms.RefSplit (

        splitXRelatedRefs

        ) where

import Prelude hiding (error)

import Data.List (partition)
import Text.PrettyPrint.HughesPJ

import Language.Haskell.Liquid.Types
import Language.Haskell.Liquid.Types.PrettyPrint ()

import Language.Fixpoint.Types hiding (Predicate)
import Language.Fixpoint.Misc

-- | Split every refinement inside a 'SpecType' according to whether it
--   mentions the symbol @x@.
--
--   Both results have exactly the same shape as the input type; only the
--   refinements differ.
--
--   * The first component keeps the refinement conjuncts and
--     abstract-predicate applications for which 'isFree' @x@ holds.
--   * The second component keeps everything else.
splitXRelatedRefs :: Symbol -> SpecType -> (SpecType, SpecType)
splitXRelatedRefs = splitRType



-- | Structurally split a refined type into two copies of the same shape.
--
--   At every node carrying a refinement, that refinement is divided with
--   'splitRef': the first copy receives the parts for which 'isFree' @f@
--   holds, the second copy receives the rest.
--
--   The following are copied unchanged into both copies:
--
--   * binders, type-variable binders and type constructors;
--   * abstract-predicate binders ('RAllP');
--   * 'Oblig' tags.
--
--   'RExprArg' expressions are duplicated as-is.
--
--   NOTE(review): binders (in 'RFun', 'RAllE', 'REx', 'RRTy') are never
--   compared with @f@. A refinement that mentions an inner binder sharing
--   @f@'s name is therefore classified as @f@-related; confirm this is
--   intended.
splitRType :: Symbol
           -> RType c tv (UReft Reft)
           -> (RType c tv (UReft Reft), RType c tv (UReft Reft))
splitRType f (RVar a r) = (RVar a r1, RVar a r2)
  where
    (r1, r2) = splitRef f r
splitRType f (RFun x i tx t r) = (RFun x i tx1 t1 r1, RFun x i tx2 t2 r2)
  where
    (tx1, tx2) = splitRType f tx
    (t1,  t2)  = splitRType f t
    (r1,  r2)  = splitRef   f r
splitRType f (RAllT v t r) = (RAllT v t1 r1, RAllT v t2 r2)
  where
    (t1, t2) = splitRType f t
    (r1, r2) = splitRef   f r
splitRType f (RAllP p t) = (RAllP p t1, RAllP p t2)
  where
    (t1, t2) = splitRType f t
splitRType f (RApp c ts rs r) = (RApp c ts1 rs1 r1, RApp c ts2 rs2 r2)
  where
    (ts1, ts2) = unzip (splitRType f <$> ts)
    (rs1, rs2) = unzip (splitUReft f <$> rs)
    (r1,  r2)  = splitRef f r
splitRType f (RAllE x tx t) = (RAllE x tx1 t1, RAllE x tx2 t2)
  where
    (tx1, tx2) = splitRType f tx
    (t1,  t2)  = splitRType f t
splitRType f (REx x tx t) = (REx x tx1 t1, REx x tx2 t2)
  where
    (tx1, tx2) = splitRType f tx
    (t1,  t2)  = splitRType f t
splitRType _ (RExprArg e) = (RExprArg e, RExprArg e)
splitRType f (RAppTy tx t r) = (RAppTy tx1 t1 r1, RAppTy tx2 t2 r2)
  where
    (tx1, tx2) = splitRType f tx
    (t1,  t2)  = splitRType f t
    (r1,  r2)  = splitRef   f r
splitRType f (RRTy xs r o rt) = (RRTy xs1 r1 o rt1, RRTy xs2 r2 o rt2)
  where
    (xs1, xs2) = unzip (splitBinding <$> xs)
    (r1,  r2)  = splitRef   f r
    (rt1, rt2) = splitRType f rt

    -- Keep the binder name on both sides; split only its type.
    splitBinding (y, ty) =
      let (ty1, ty2) = splitRType f ty
      in  ((y, ty1), (y, ty2))
splitRType f (RHole r) = (RHole r1, RHole r2)
  where
    (r1, r2) = splitRef f r


-- | Split an abstract-refinement argument ('RTProp').
--
--   The parameter list @xs@ is copied into both halves, and the body type is
--   split with 'splitRType'. A bare 'RHole' body needs no special case here:
--   'splitRType' already splits an 'RHole' refinement with 'splitRef'.
splitUReft :: Symbol -> RTProp c tv (UReft Reft) -> (RTProp c tv (UReft Reft), RTProp c tv (UReft Reft))
splitUReft x (RProp xs t) = (RProp xs t1, RProp xs t2)
  where
    (t1, t2) = splitRType x t

-- | Split a 'UReft' into its @f@-related part and its unrelated part.
--
--   The 'Reft' component is split with 'splitReft'; the abstract-predicate
--   component is split with 'splitPred'.
splitRef :: Symbol -> UReft Reft -> (UReft Reft, UReft Reft)
splitRef f (MkUReft r p) = (MkUReft r1 p1, MkUReft r2 p2)
  where
    (r1, r2) = splitReft f r
    (p1, p2) = splitPred f p

-- | Split a 'Reft' by partitioning its top-level conjuncts.
--
--   1. Nested 'PAnd's are flattened first.
--   2. Conjuncts for which 'isFree' @f@ holds go to the first result; the
--      remaining conjuncts go to the second.
--   3. Each side is re-assembled with 'pAnd' and keeps the original value
--      variable @v@.
splitReft :: Symbol -> Reft -> (Reft, Reft)
splitReft f (Reft (v, pr)) = (Reft (v, pAnd related), Reft (v, pAnd unrelated))
  where
    (related, unrelated) = partition (isFree f) (conjuncts pr)

    -- Flatten arbitrarily nested conjunctions into a flat list of conjuncts.
    conjuncts :: Expr -> [Expr]
    conjuncts (PAnd ps) = concatMap conjuncts ps
    conjuncts p         = [p]


-- | Split an abstract-predicate refinement.
--
--   Each applied predicate variable goes to the first result if any of its
--   argument expressions satisfies 'isFree' @f@. The argument expression is
--   the third component of each 'pargs' entry. All other predicate variables
--   go to the second result.
splitPred :: Symbol -> Predicate -> (Predicate, Predicate)
splitPred f (Pr ps) = (Pr ps1, Pr ps2)
  where
    (ps1, ps2) = partition mentionsF ps
    mentionsF pv = any (isFree f . thd3) (pargs pv)


-- | @isFree x e@ holds when @x@ is one of the symbols that 'syms' reports
--   for @e@.
--
--   This is a plain constrained function, so it needs no blanket instance
--   and no @UndecidableInstances@.
--
--   NOTE(review): the name assumes 'syms' yields only the free symbols of
--   @e@. That depends on the 'Subable' instance for the argument type, so
--   confirm it for 'Expr'.
isFree :: Subable a => Symbol -> a -> Bool
isFree x e = x `elem` syms e

-- | Show a @'UReft' 'Reft'@ by pretty-printing it with 'pprint' and then
--   calling 'render'.
--
--   This is an orphan instance, which is why the module disables
--   @-Wno-orphans@ in its header.
instance Show (UReft Reft) where
  show = render . pprint