module Language.Haskell.FreeTheorems.Intermediate (
Intermediate (..)
, interpret
, interpretM
, relationVariables
, specialise
, specialiseInverse
) where
import Control.Monad (liftM, liftM2, mapM)
import Control.Monad.Reader (ReaderT, ask, runReaderT, local)
import Control.Monad.State (State, get, put, runState)
import Control.Monad.Trans (lift)
import Data.Generics ( Typeable, Data, everywhere, everything, listify, mkT
, mkQ, extQ)
import qualified Data.Map as Map (Map, empty, lookup, insert, map)
import Language.Haskell.FreeTheorems.LanguageSubsets
import Language.Haskell.FreeTheorems.Syntax
import Language.Haskell.FreeTheorems.ValidSyntax
import Language.Haskell.FreeTheorems.Theorems
import Language.Haskell.FreeTheorems.Frontend.TypeExpressions
( substituteTypeVariables )
import Language.Haskell.FreeTheorems.NameStores
( relationNameStore, typeExpressionNameStore, functionNameStore1, functionNameStore2 )
-- | Lift a 'Maybe' into an arbitrary monad.
--
-- A 'Just' value is injected with 'return'.  'Nothing' aborts with the
-- message @"Data.Map.lookup: Key not found"@.
--
-- The abort is raised with 'error' instead of 'fail': 'fail' is no longer a
-- method of 'Monad' (it belongs to @MonadFail@), so calling it under a plain
-- @Monad m@ constraint is rejected by current compilers.  The constraint
-- cannot simply be tightened either, because the monad 'interpretM' runs in
-- ('ReaderT' over 'State') has no @MonadFail@ instance.
maybeToMonad :: Monad m => Maybe a -> m a
maybeToMonad = maybe (error "Data.Map.lookup: Key not found") return
-- | Result of interpreting one signature with 'interpret': the relation
-- obtained from the signature's type together with the names that are
-- still available for later specialisation steps.
data Intermediate = Intermediate
{ intermediateName :: String
-- ^ Name of the signature that was interpreted.
, intermediateSubset :: LanguageSubset
-- ^ Language subset the interpretation was done for.
, intermediateRelation :: Relation
-- ^ Relation obtained from the signature's type.
, functionVariableNames1 :: [String]
-- ^ Entries of 'functionNameStore1' that clash with no signature name.
-- 'replaceRelVar' takes the first entry each time it introduces a function.
, functionVariableNames2 :: [String]
-- ^ Entries of 'functionNameStore2' that clash with no signature name.
-- Nothing in the visible part of this module consumes them.
, signatureNames :: [String]
-- ^ Names of all signatures found in the declarations given to 'interpret'.
, interpretNameStore :: NameStore
-- ^ Name store as it was left when the interpretation finished.
}
-- | Interpret the type of a signature as a relation.
--
-- The declarations are used for two things: the names of all signatures in
-- them (plus the name of the interpreted signature itself) are excluded from
-- every name supply, and they decide whether the signature is acceptable.
--
-- For @SubsetWithSeq@ a result is always produced.  For every other subset
-- the result is 'Nothing' when the signature mentions, as a type constructor
-- or as a type class, the name of a declaration for which
-- 'isStrictDeclaration' holds.
interpret ::
  [ValidDeclaration] -> LanguageSubset -> ValidSignature -> Maybe Intermediate
interpret vds l s =
  case l of
    SubsetWithSeq _ -> Just result
    _
      | mentionsStrictDeclaration -> Nothing
      | otherwise -> Just result
  where
    raw = rawSignature s
    name = unpackIdent (signatureName raw)

    -- Names of every signature occurring anywhere in the declarations.
    otherNames =
      everything (++) ([] `mkQ` sigName) (map rawDeclaration vds)
    sigName (Signature ident _) = [unpackIdent ident]

    -- Names that generated variables must not collide with.
    taken = name : otherNames
    unused = filter (`notElem` taken)

    -- Run the interpretation with an empty environment; the final state is
    -- kept so that it can be stored in the result.
    (rel, store) =
      runState
        (runReaderT (interpretM l (signatureType raw)) Map.empty)
        (initialState taken)

    result =
      Intermediate
        { intermediateName = name
        , intermediateSubset = l
        , intermediateRelation = rel
        , functionVariableNames1 = unused functionNameStore1
        , functionVariableNames2 = unused functionNameStore2
        , signatureNames = otherNames
        , interpretNameStore = store
        }

    -- Does the signature refer to any declaration flagged as strict?
    mentionsStrictDeclaration = any (`elem` strictNames) usedNames
    usedNames =
      everything (++) ([] `mkQ` conName `extQ` className) raw
    strictNames =
      map (getDeclarationName . rawDeclaration) (filter isStrictDeclaration vds)

    conName (Con ident) = [ident]
    conName _ = []
    className (TC ident) = [ident]
-- | Interpret a type expression as a relation.
--
-- The reader environment maps every type variable in scope to the relation
-- chosen for it; the state supplies fresh relation variables.
--
-- * A type variable is looked up in the environment (a missing entry aborts
--   via 'maybeToMonad').
-- * A type constructor becomes @RelBasic@ when all argument relations are
--   @RelBasic@, and a @RelLift@ of the argument relations otherwise.
-- * Both function types become the corresponding function relation; the
--   argument type is interpreted before the result type.
-- * Both abstractions draw a fresh relation variable, interpret the body with
--   the bound type variable mapped to it, and attach the restrictions of the
--   language subset (without @BottomReflecting@ for @TypeAbsLab@) plus
--   @RespectsClasses@ if the binder carries class constraints.
--
-- NOTE(review): @TypeExp@ is what 'initialState' builds for generated type
-- names; presumably it never occurs in a type handed to 'interpret' —
-- confirm.  It is rejected with an explicit error here instead of an
-- anonymous pattern-match failure.
interpretM ::
  LanguageSubset ->
  TypeExpression ->
  ReaderT Environment (State NameStore) Relation
interpretM l t = case t of
  TypeVar v -> maybeToMonad . Map.lookup v =<< ask
  TypeCon c ts -> do
    rs <- mapM (interpretM l) ts
    ri <- relationInfoFor t
    return $
      if all isBasic rs
        then RelBasic (RelationInfo l t t)
        else RelLift ri c rs
  TypeFun t1 t2 -> arrow RelFun t1 t2
  TypeFunLab t1 t2 -> arrow RelFunLab t1 t2
  TypeAbs v cs t' -> abstraction v cs t' restrictions
  TypeAbsLab v cs t' ->
    abstraction v cs t' (filter (/= BottomReflecting) restrictions)
  TypeExp _ ->
    error "interpretM: cannot interpret a TypeExp type expression"
  where
    isBasic (RelBasic _) = True
    isBasic _ = False

    -- Shared by both function types; @mk@ is the relation constructor.
    arrow mk t1 t2 = do
      ri <- relationInfoFor t
      liftM2 (mk ri) (interpretM l t1) (interpretM l t2)

    -- Shared by both abstractions; @base@ holds the subset restrictions.
    abstraction v cs body base = do
      ri <- relationInfoFor t
      (rv, t1, t2) <- lift newRelationVariable
      let rvar = RelVar (RelationInfo l t1 t2) rv
      r <- local (Map.insert v rvar) (interpretM l body)
      let res = base ++ (if null cs then [] else [RespectsClasses cs])
      return (RelAbs ri rv (t1, t2) res r)

    -- Left and right type of the relation for @ty@: every type variable in
    -- scope is replaced by the left (respectively right) type of the
    -- relation currently bound to it.
    relationInfoFor ty = do
      env <- ask
      let leftOf = relationLeftType . relationInfo
          rightOf = relationRightType . relationInfo
          lt = substituteTypeVariables (Map.map leftOf env) ty
          rt = substituteTypeVariables (Map.map rightOf env) ty
      return (RelationInfo l lt rt)

    -- Restrictions put on relation variables in the language subset @l@.
    restrictions = case l of
      BasicSubset -> []
      SubsetWithFix EquationalTheorem -> [Strict, Continuous]
      SubsetWithFix InequationalTheorem -> [Strict, Continuous, LeftClosed]
      SubsetWithSeq EquationalTheorem ->
        [Strict, Continuous, BottomReflecting]
      SubsetWithSeq InequationalTheorem ->
        [Strict, Continuous, Total, LeftClosed]
-- | Maps each type variable that is in scope during interpretation to the
-- relation standing for it.  'interpretM' extends it at every type
-- abstraction and consults it at every type variable.
type Environment = Map.Map TypeVariable Relation
-- | Supply of fresh names: the names still available for relation variables
-- and the type expressions still available as the types they relate.
-- 'newRelationVariable' takes one of the former and two of the latter.
type NameStore = ([String], [TypeExpression])
-- | Build the name store an interpretation starts with.
--
-- Relation names are taken from 'relationNameStore' unchanged.  Type
-- expressions are built from the entries of 'typeExpressionNameStore',
-- skipping every entry that occurs in the given list of names.
initialState :: [String] -> NameStore
initialState ns = (relationNameStore, freshTypes)
  where
    freshTypes =
      [ TypeExp (TF (Ident name))
      | name <- typeExpressionNameStore
      , name `notElem` ns
      ]
-- | Draw a fresh relation variable together with the two type expressions
-- it relates.
--
-- One name is removed from the relation-name supply and two entries are
-- removed from the type-expression supply.  If either supply runs out, the
-- missing component is an 'error' with a descriptive message; like the
-- results themselves it is only raised once that component is demanded.
newRelationVariable ::
  State NameStore (RelationVariable, TypeExpression, TypeExpression)
newRelationVariable = do
  (rvs, ts) <- get
  let rv = case rvs of
        (name : _) -> name
        [] -> error "newRelationVariable: relation name supply exhausted"
      (t1, t2) = case ts of
        (a : b : _) -> (a, b)
        _ -> error "newRelationVariable: type expression supply exhausted"
  put (drop 1 rvs, drop 2 ts)
  return (RVar rv, t1, t2)
-- | Relation variables of an intermediate structure that are bound by a
-- @RelAbs@ while the collection flag is set.
--
-- The flag starts out set, is inverted for the argument side of every
-- function relation and is passed on unchanged everywhere else.  Variables
-- are listed in traversal order (argument side before result side).
relationVariables :: Intermediate -> [RelationVariable]
relationVariables = collect True . intermediateRelation
  where
    collect keep (RelLift _ _ rs) = concatMap (collect keep) rs
    collect keep (RelFun _ dom cod) = collect (not keep) dom ++ collect keep cod
    collect keep (RelFunLab _ dom cod) =
      collect (not keep) dom ++ collect keep cod
    collect keep (RelAbs _ rv _ _ body) = [rv | keep] ++ collect keep body
    collect keep (FunAbs _ _ _ _ body) = collect keep body
    collect _ _ = []
-- | Replace the given relation variable by a function (tagged @Left@) and
-- simplify the lifted relations that become reducible as a result.
specialise :: Intermediate -> RelationVariable -> Intermediate
specialise ir rv = reduceLifts substituted
  where
    substituted = replaceRelVar ir rv Left
-- | Replace the given relation variable by a function tagged @Right@ and
-- simplify the result.  This is done for inequational theorems only; for
-- equational theorems the intermediate structure is returned unchanged.
specialiseInverse :: Intermediate -> RelationVariable -> Intermediate
specialiseInverse ir rv = byTheoremType (theoremType (intermediateSubset ir))
  where
    byTheoremType EquationalTheorem = ir
    byTheoremType InequationalTheorem =
      reduceLifts (replaceRelVar ir rv Right)
-- | Replace a relation variable by a function variable everywhere in the
-- relation of an intermediate structure.
--
-- The function's name is the first entry of 'functionVariableNames1', which
-- is removed from that list in the result.  The third argument decides
-- whether the function is tagged @Left@ or @Right@.
--
-- * Every occurrence of the relation variable becomes a @FunVar@.
-- * The @RelAbs@ binding it becomes a @FunAbs@ whose restrictions are the
--   function restrictions for the chosen side followed by the
--   @RespectsClasses@ restrictions of the original binder.  @Total@ is kept
--   only if the binder was restricted by @BottomReflecting@ or @Total@.
--
-- If no name is left, the name is an 'error' with a descriptive message; it
-- is only raised when the name is actually demanded.
--
-- NOTE(review): the name is drawn from 'functionVariableNames1' for both
-- sides; 'functionVariableNames2' is never touched here — confirm that this
-- is intended.
replaceRelVar ::
  Intermediate ->
  RelationVariable ->
  (TermVariable -> Either TermVariable TermVariable) ->
  Intermediate
replaceRelVar ir (RVar rv) leftOrRight =
  ir
    { intermediateRelation =
        everywhere (mkT replace) (intermediateRelation ir)
    , functionVariableNames1 = drop 1 names
    }
  where
    names = functionVariableNames1 ir
    funName = case names of
      (name : _) -> name
      [] -> error "replaceRelVar: no unused function variable name left"
    fv = leftOrRight (TVar funName)

    replace rel = case rel of
      RelVar ri (RVar r)
        | r == rv -> FunVar ri (either (Left . TermVar) (Right . TermVar) fv)
      RelAbs ri (RVar r) ts res body
        | r == rv ->
            FunAbs ri fv ts (funRes res ++ classConstraints res) body
      _ -> rel

    -- Restrictions of the new function, given those of the replaced binder.
    funRes res
      | BottomReflecting `elem` res || Total `elem` res = sideRes
      | otherwise = filter (/= Total) sideRes
    sideRes = either (const funResL) (const funResR) fv

    funResL = case intermediateSubset ir of
      BasicSubset -> []
      SubsetWithFix _ -> [Strict]
      SubsetWithSeq _ -> [Strict, Total]
    funResR = case intermediateSubset ir of
      BasicSubset -> []
      SubsetWithFix _ -> []
      SubsetWithSeq _ -> [Strict]

    classConstraints = filter isClassConstraint
    isClassConstraint (RespectsClasses _) = True
    isClassConstraint _ = False
-- | Simplify lifted relations whose arguments are all functions.
--
-- A @RelLift@ over the list constructor or over @Maybe@ whose arguments are
-- all @Left@ functions (or, failing that, all @Right@ functions) is replaced
-- by a single function: @map@ respectively @fmap@, instantiated with the
-- types of the argument functions and applied to them.  Arguments are
-- simplified first.
--
-- A flag controls whether a lift may be touched at all.  It starts out set.
-- Below a function relation it is always set for equational theorems; for
-- inequational theorems it is inverted on the argument side (but set anyway
-- if the argument itself is a list lift) and unchanged on the result side.
reduceLifts :: Intermediate -> Intermediate
reduceLifts ir =
  ir {intermediateRelation = simplify True (intermediateRelation ir)}
  where
    simplify allowed rel = case rel of
      RelLift ri con rs
        | allowed ->
            let rs' = map (simplify allowed) rs
             in maybe (RelLift ri con rs') id (toTerm ri con rs')
        | otherwise -> rel
      RelFun ri dom cod ->
        RelFun ri (simplifyArgument allowed ri dom) (simplifyResult allowed ri cod)
      RelFunLab ri dom cod ->
        RelFunLab ri (simplifyArgument allowed ri dom) (simplifyResult allowed ri cod)
      RelAbs ri rv ts res body -> RelAbs ri rv ts res (simplify allowed body)
      FunAbs ri fv ts res body -> FunAbs ri fv ts res (simplify allowed body)
      _ -> rel

    simplifyArgument allowed ri dom =
      simplify (argumentAllowed (not allowed) ri dom) dom
    simplifyResult allowed ri cod =
      simplify (resultAllowed allowed ri) cod

    argumentAllowed flag ri dom =
      case theoremType (relationLanguageSubset ri) of
        EquationalTheorem -> True
        InequationalTheorem -> case dom of
          RelLift _ ConList _ -> True
          _ -> flag

    resultAllowed flag ri =
      case theoremType (relationLanguageSubset ri) of
        EquationalTheorem -> True
        InequationalTheorem -> flag

    -- The function replacing a lift, if its constructor has a known mapping
    -- function and all arguments are functions of the same side.
    toTerm ri con rs = case funSymbol con of
      Nothing -> Nothing
      Just f -> case mapM leftFun rs of
        Just fts -> Just (FunVar ri (Left (term f fts)))
        Nothing -> fmap (FunVar ri . Right . term f) (mapM rightFun rs)

    funSymbol ConList = Just (TVar "map")
    funSymbol (Con (Ident "Maybe")) = Just (TVar "fmap")
    funSymbol _ = Nothing

    -- A function together with its (source, target) types; for a @Right@
    -- function the two types of the relation are swapped.
    leftFun (FunVar ri (Left f)) =
      Just (f, (relationLeftType ri, relationRightType ri))
    leftFun _ = Nothing

    rightFun (FunVar ri (Right f)) =
      Just (f, (relationRightType ri, relationLeftType ri))
    rightFun _ = Nothing

    -- Instantiate @f@ with all type pairs, then apply it to all functions.
    term f fts = foldl TermApp instantiated (map fst fts)
      where
        instantiated = foldl instantiate (TermVar f) (map snd fts)
        instantiate tm (t1, t2) = TermIns (TermIns tm t1) t2