-- | Equalities over thread identifiers, roles, agent identifiers, variables
-- and messages. The module provides the individual equality types, the
-- 'Equalities' store, substitution functions that apply a store, the solver
-- 'solve', and the 'Mapping' wrapper together with pretty-printers.
module Scyther.Equalities (
-- Single equalities and conversions between them
TIDEq
, TIDRoleEq
, RoleEq
, AgentEqRHS
, AgentEq
, AVarEq
, MVarEq
, MsgEq
, AnyEq(..)
, agentEqToMsgEq
, mvarEqToMsgEq
-- The store of equalities: construction and solving
, Equalities
, empty
, solve
, trimTIDEqs
, trimAgentEqs
-- Extracting the stored equalities
, getTIDEqs
, getTIDRoleEqs
, getAgentEqs
, getAVarEqs
, getMVarEqs
, getPostEqs
, toAnyEqs
, anyEqTIDs
-- Applying the store as a substitution
, substTID
, substLocalId
, substAVar
, substMVar
, substAgentId
, substAgentEqRHS
, substMsg
, substAnyEq
-- Queries
, threadRole
, maxMappedTID
, maxMappedAgentId
, reflexive
, null
-- Mappings: a newtype wrapper around 'Equalities'
, Mapping(..)
, emptyMapping
, mkMapping
, addTIDMapping
, addAgentIdMapping
, addTIDRoleMapping
, deleteTIDMapping
, deleteAgentIdMapping
-- Pretty printing
, sptAnyEq
) where
import Prelude hiding (null)
import qualified Data.Map as M
import qualified Data.UnionFind as U
import Data.Data
import Control.Arrow ( (***) )
import Control.Monad
import Text.Isar
import Scyther.Protocol
import Scyther.Message
-- | An equality between two thread identifiers.
type TIDEq = (TID, TID)
-- | Stored thread identifier equalities: each key is replaced by its value
-- (see 'substTID').
type TIDEqs = M.Map TID TID
-- | An assignment of a role to a thread identifier.
type TIDRoleEq = (TID, Role)
-- | Stored role assignments, keyed by thread identifier.
type TIDRoleEqs = M.Map TID Role
-- | An equality between two roles.
type RoleEq = (Role, Role)
-- | The right-hand side of an agent equality: either another agent
-- identifier ('Left') or an 'AVar' ('Right').
type AgentEqRHS = Either AgentId AVar
-- | An equality between an agent identifier and an 'AgentEqRHS'.
type AgentEq = (AgentId, AgentEqRHS)
-- | Stored agent equalities: each key is replaced by its value
-- (see 'substAgentId').
type AgentEqs = M.Map AgentId AgentEqRHS
-- | An equality between two 'AVar's.
type AVarEq = (AVar, AVar)
-- | Stored 'AVar' equalities: each key is replaced by its value
-- (see 'substAVar').
type AVarEqs = M.Map AVar AVar
-- | An equality between an 'MVar' and a message.
type MVarEq = (MVar, Message)
-- | Stored 'MVar' equalities: each key is replaced by its value
-- (see 'substMVar').
type MVarEqs = M.Map MVar Message
-- | An equality between two messages.
type MsgEq = (Message, Message)
-- | Message equalities kept in a union-find structure; the solver stores the
-- 'MShrK' equalities here that it cannot decompose yet.
type MsgEqs = U.UnionFind Message
-- | A single equality of any of the supported kinds. All payloads are
-- strict. The constructor order matters for the derived 'Ord' instance.
data AnyEq
  = TIDEq     !TIDEq      -- ^ two thread identifiers are equal
  | TIDRoleEq !TIDRoleEq  -- ^ a thread identifier is assigned a role
  | RoleEq    !RoleEq     -- ^ two roles are equal
  | AgentEq   !AgentEq    -- ^ an agent identifier equals an 'AgentEqRHS'
  | AVarEq    !AVarEq     -- ^ two 'AVar's are equal
  | MVarEq    !MVarEq     -- ^ an 'MVar' equals a message
  | MsgEq     !MsgEq      -- ^ two messages are equal
  deriving (Eq, Ord, Show, Data, Typeable)
-- | A store of equalities, split by kind. The map-valued fields are applied
-- as substitutions by the @subst*@ functions of this module. The field
-- order is relied upon by positional uses of the constructor.
data Equalities = Equalities
  { tidEqs  :: TIDEqs      -- ^ thread identifier replacements
  , roleEqs :: TIDRoleEqs  -- ^ role assigned to a thread identifier
  , avarEqs :: AVarEqs     -- ^ 'AVar' replacements
  , mvarEqs :: MVarEqs     -- ^ 'MVar' replacements
  , agntEqs :: AgentEqs    -- ^ agent identifier replacements
  , postEqs :: MsgEqs      -- ^ message equalities held in a union-find
  }
  deriving (Eq, Ord, Show, Data, Typeable)
-- | The store containing no equalities at all.
empty :: Equalities
empty = Equalities
  { tidEqs  = M.empty
  , roleEqs = M.empty
  , avarEqs = M.empty
  , mvarEqs = M.empty
  , agntEqs = M.empty
  , postEqs = U.empty
  }
-- | 'True' iff the given store is equal to 'empty'.
null :: Equalities -> Bool
null eqs = empty == eqs
-- | Replace a thread identifier by the value stored for it in 'tidEqs';
-- unmapped identifiers are returned unchanged.
substTID :: Equalities -> TID -> TID
substTID eqs tid = case M.lookup tid (tidEqs eqs) of
  Just tid' -> tid'
  Nothing   -> tid
-- | Apply 'substTID' to the thread identifier component of a local
-- identifier; the other component is kept.
substLocalId :: Equalities -> LocalId -> LocalId
substLocalId eqs (LocalId (i, tid)) = LocalId (i, tid')
  where
    tid' = substTID eqs tid
-- | Substitute an 'AVar': first its local identifier is rewritten with
-- 'substLocalId', then the result is looked up in 'avarEqs'. If there is no
-- entry, the rewritten variable itself is returned.
substAVar :: Equalities -> AVar -> AVar
substAVar eqs av = maybe renamed id (M.lookup renamed (avarEqs eqs))
  where
    renamed = mapAVar (substLocalId eqs) av
-- | Substitute an 'MVar': first its local identifier is rewritten with
-- 'substLocalId', then the result is looked up in 'mvarEqs'. If there is no
-- entry, the rewritten variable is returned wrapped in 'MMVar'.
substMVar :: Equalities -> MVar -> Message
substMVar eqs mv = case M.lookup renamed (mvarEqs eqs) of
  Just msg -> msg
  Nothing  -> MMVar renamed
  where
    renamed = mapMVar (substLocalId eqs) mv
-- | Look up an agent identifier in 'agntEqs'. An unmapped identifier is
-- returned as @Left aid@.
substAgentId :: Equalities -> AgentId -> AgentEqRHS
substAgentId eqs aid = case M.lookup aid (agntEqs eqs) of
  Just rhs -> rhs
  Nothing  -> Left aid
-- | Substitute the right-hand side of an agent equality: agent identifiers
-- go through 'substAgentId', 'AVar's through 'substAVar'.
substAgentEqRHS :: Equalities -> AgentEqRHS -> AgentEqRHS
substAgentEqRHS eqs rhs = case rhs of
  Left aid -> substAgentId eqs aid
  Right av -> Right (substAVar eqs av)
-- | Apply the equalities to a message and normalise the result with
-- 'normMsg'.
--
-- Constants are kept; fresh values have their local identifier rewritten;
-- 'AVar's, 'MVar's and agent identifiers are substituted; all other
-- constructors are traversed structurally. An 'MShrK' message is, after its
-- components have been rewritten, additionally looked up in 'postEqs'.
substMsg :: Equalities -> Message -> Message
substMsg eqs = normMsg . rewrite
  where
    rewrite msg = case msg of
      MConst _   -> msg
      MFresh fr  -> MFresh (mapFresh (substLocalId eqs) fr)
      MAVar av   -> MAVar (substAVar eqs av)
      MMVar mv   -> substMVar eqs mv
      MAgent aid -> either MAgent MAVar (substAgentId eqs aid)
      MHash m    -> MHash (rewrite m)
      MTup l r   -> MTup (rewrite l) (rewrite r)
      MEnc l r   -> MEnc (rewrite l) (rewrite r)
      MSymK l r  -> MSymK (rewrite l) (rewrite r)
      MShrK l r  ->
        let shr = MShrK (rewrite l) (rewrite r)
        in  U.findWithDefault shr shr (postEqs eqs)
      MAsymPK m  -> MAsymPK (rewrite m)
      MAsymSK m  -> MAsymSK (rewrite m)
      MInvKey m  -> MInvKey (rewrite m)
-- | Apply 'substTID' to both sides of a thread identifier equality.
substTIDEq :: Equalities -> TIDEq -> TIDEq
substTIDEq eqs eq = (rename (fst eq), rename (snd eq))
  where
    rename = substTID eqs
-- | Substitute a role assignment. If the (substituted) thread already has a
-- role in the store, the result is a 'RoleEq' between the given and the
-- stored role; otherwise it is the assignment for the substituted thread.
substTIDRoleEq :: Equalities -> TIDRoleEq -> AnyEq
substTIDRoleEq eqs (tid, role) =
  maybe unresolved resolved (threadRole tid' eqs)
  where
    tid'           = substTID eqs tid
    unresolved     = TIDRoleEq (tid', role)
    resolved role' = RoleEq (role, role')
-- | Convert an agent equality to a message equality and substitute both
-- sides.
substAgentEq :: Equalities -> AgentEq -> MsgEq
substAgentEq eqs eq = substMsgEq eqs (agentEqToMsgEq eq)
-- | Apply 'substAVar' to both sides of an 'AVar' equality.
substAVarEq :: Equalities -> AVarEq -> AVarEq
substAVarEq eqs eq = (rename (fst eq), rename (snd eq))
  where
    rename = substAVar eqs
-- | Substitute an 'MVar' equality: the variable through 'substMVar', the
-- message through 'substMsg'. The result is a message equality because the
-- variable may itself be replaced by a message.
substMVarEq :: Equalities -> MVarEq -> MsgEq
substMVarEq eqs eq = (substMVar eqs (fst eq), substMsg eqs (snd eq))
-- | Apply 'substMsg' to both sides of a message equality.
substMsgEq :: Equalities -> MsgEq -> MsgEq
substMsgEq eqs eq = (rewrite (fst eq), rewrite (snd eq))
  where
    rewrite = substMsg eqs
-- | Substitute an arbitrary equality. Role equalities are returned
-- unchanged; agent and 'MVar' equalities come back as message equalities;
-- a role assignment may turn into a role equality (see 'substTIDRoleEq').
substAnyEq :: Equalities -> AnyEq -> AnyEq
substAnyEq eqs (TIDEq eq)     = TIDEq (substTIDEq eqs eq)
substAnyEq eqs (TIDRoleEq eq) = substTIDRoleEq eqs eq
substAnyEq _   eq0@(RoleEq _) = eq0
substAnyEq eqs (AgentEq eq)   = MsgEq (substAgentEq eqs eq)
substAnyEq eqs (AVarEq eq)    = AVarEq (substAVarEq eqs eq)
substAnyEq eqs (MVarEq eq)    = MsgEq (substMVarEq eqs eq)
substAnyEq eqs (MsgEq eq)     = MsgEq (substMsgEq eqs eq)
-- | View an agent equality as a message equality: the left side becomes an
-- 'MAgent', the right side an 'MAgent' or an 'MAVar'.
agentEqToMsgEq :: AgentEq -> MsgEq
agentEqToMsgEq (aid, rhs) = (MAgent aid, rhsMsg)
  where
    rhsMsg = case rhs of
      Left aid' -> MAgent aid'
      Right av  -> MAVar av
-- | View an 'MVar' equality as a message equality by wrapping the variable
-- in 'MMVar'.
mvarEqToMsgEq :: MVarEq -> MsgEq
mvarEqToMsgEq (mv, msg) = (lhs, msg)
  where
    lhs = MMVar mv
-- | 'True' iff both sides of the equality are syntactically equal. Role
-- assignments are never reflexive; agent and 'MVar' equalities are compared
-- after conversion to message equalities.
reflexive :: AnyEq -> Bool
reflexive (TIDEq eq)    = sameSides eq
reflexive (TIDRoleEq _) = False
reflexive (RoleEq eq)   = sameSides eq
reflexive (AgentEq eq)  = reflexive (MsgEq (agentEqToMsgEq eq))
reflexive (AVarEq eq)   = sameSides eq
reflexive (MVarEq eq)   = reflexive (MsgEq (mvarEqToMsgEq eq))
reflexive (MsgEq eq)    = sameSides eq

-- | Both components of a pair are equal.
sameSides :: Eq a => (a, a) -> Bool
sameSides eq = fst eq == snd eq
-- | The stored thread identifier equalities as a list of pairs.
getTIDEqs :: Equalities -> [TIDEq]
getTIDEqs eqs = M.toList (tidEqs eqs)
-- | The stored role assignments as a list of pairs.
getTIDRoleEqs :: Equalities -> [TIDRoleEq]
getTIDRoleEqs eqs = M.toList (roleEqs eqs)
-- | The stored agent equalities as a list of pairs.
getAgentEqs :: Equalities -> [AgentEq]
getAgentEqs eqs = M.toList (agntEqs eqs)
-- | The stored 'AVar' equalities as a list of pairs.
getAVarEqs :: Equalities -> [AVarEq]
getAVarEqs eqs = M.toList (avarEqs eqs)
-- | The stored 'MVar' equalities as a list of pairs.
getMVarEqs :: Equalities -> [MVarEq]
getMVarEqs eqs = M.toList (mvarEqs eqs)
-- | The message equalities held in the union-find ('postEqs') as a list of
-- pairs.
getPostEqs :: Equalities -> [MsgEq]
getPostEqs eqs = U.toList (postEqs eqs)
-- | All components of the store as lists, in the order: thread identifier
-- equalities, role assignments, agent equalities, 'AVar' equalities, 'MVar'
-- equalities, and the message equalities of 'postEqs'.
toLists :: Equalities -> ([TIDEq], [TIDRoleEq], [AgentEq], [AVarEq], [MVarEq], [MsgEq])
toLists eqs = (tids, roles, agents, avars, mvars, posts)
  where
    tids   = getTIDEqs eqs
    roles  = getTIDRoleEqs eqs
    agents = getAgentEqs eqs
    avars  = getAVarEqs eqs
    mvars  = getMVarEqs eqs
    posts  = getPostEqs eqs
-- | Flatten the store into one list of tagged equalities, in the component
-- order of 'toLists'.
toAnyEqs :: Equalities -> [AnyEq]
toAnyEqs eqs = concat
  [ map TIDEq tids
  , map TIDRoleEq roles
  , map AgentEq agents
  , map AVarEq avars
  , map MVarEq mvars
  , map MsgEq posts
  ]
  where
    (tids, roles, agents, avars, mvars, posts) = toLists eqs
-- | The thread identifiers mentioned by an equality.
--
-- For a 'TIDEq' both sides are reported. Previously only the left thread
-- identifier was returned, although every other case collects the thread
-- identifiers of both sides. Role equalities and the agent identifier side
-- of an agent equality contribute no thread identifiers.
anyEqTIDs :: AnyEq -> [TID]
anyEqTIDs eq = case eq of
  TIDEq (tid1, tid2) -> return tid1 `mplus` return tid2
  TIDRoleEq (tid, _) -> return tid
  RoleEq (_, _)      -> mzero
  AgentEq (_, rhs)   -> either (const mzero) (return . avarTID) rhs
  AVarEq (a1, a2)    -> return (avarTID a1) `mplus` return (avarTID a2)
  MVarEq (v, m)      -> return (mvarTID v) `mplus` msgTIDs m
  MsgEq (m1, m2)     -> msgTIDs m1 `mplus` msgTIDs m2
-- | Rewrite every message in 'postEqs' with 'substMsg'. The substitution
-- used is the store with its own 'postEqs' emptied, so the rewriting does
-- not consult the equalities that are being rewritten.
normPostEqs :: Equalities -> Equalities
normPostEqs eqs0 = cleared { postEqs = rewritten }
  where
    cleared   = eqs0 { postEqs = U.empty }
    rewritten = U.map (substMsg cleared) (postEqs eqs0)
-- | Solve the given equalities on top of an existing store and return the
-- extended store. Failure to unify is reported through the monad's 'fail'
-- (see 'solve1').
solve :: Monad m => [AnyEq] -> Equalities -> m Equalities
solve ueqs eqs = liftM fst (solveRepeated ueqs eqs False)
-- | Worker for 'solve'. Processes the queue of unsolved equalities one at a
-- time with 'solve1'; equalities produced by a step are appended to the end
-- of the queue and 'postEqs' is renormalised after every step.
--
-- The flag records whether any step reported an improvement of the store.
-- When the queue runs empty with the flag set, the equalities in 'postEqs'
-- are queued again as message equalities and the flag is reset. With the
-- flag cleared the store is returned.
solveRepeated :: Monad m => [AnyEq] -> Equalities -> Bool -> m (Equalities, Bool)
solveRepeated pending eqs improved = case pending of
  []
    | improved  -> solveRepeated (map MsgEq (getPostEqs eqs)) eqs False
    | otherwise -> return (eqs, False)
  ueq : rest -> do
    (new, eqs', improved') <- solve1 ueq eqs
    solveRepeated (rest ++ new) (normPostEqs eqs') (improved || improved')
-- | Perform one solving step for a single equality.
--
-- The result is a triple of
--
--   * new equalities that still have to be solved,
--   * the (possibly updated) store, and
--   * a flag telling whether the store was improved by this step.
--
-- If the equality cannot hold, the step fails via the monad's 'fail' with a
-- message prefixed by @"solve1: "@.
solve1 :: Monad m => AnyEq -> Equalities -> m ([AnyEq], Equalities, Bool)
solve1 ueq eqs@(Equalities tideqs roleeqs aveqs mveqs agnteqs posteqs) =
  case ueq of
    TIDEq (tid1, tid2) ->
      let tid1' = substTID eqs tid1
          tid2' = substTID eqs tid2
          -- Eliminating a thread identifier invalidates every other stored
          -- equality, so all of them are queued again and the new store
          -- only keeps the thread identifier map.
          elimTID x y = return
            ( mkAnyEqs TIDRoleEq roleeqs ++ mkAnyEqs AgentEq agnteqs ++
              mkAnyEqs AVarEq    aveqs   ++ mkAnyEqs MVarEq  mveqs   ++
              map MsgEq (U.toList posteqs)
              -- Existing entries that point at @x@ must be redirected to
              -- @y@; otherwise 'substTID' could return the eliminated
              -- identifier @x@. This mirrors the other elimination cases.
            , empty { tidEqs = M.insert x y $ M.map (substTID elimEqs) tideqs }
            , True
            )
            where
              elimEqs = empty { tidEqs = M.singleton x y }
              mkAnyEqs :: ((k, v) -> AnyEq) -> M.Map k v -> [AnyEq]
              mkAnyEqs constr = map constr . M.toList
      in
        elimVarEqVar elimTID (tid1', tid1') (tid2', tid2')

    TIDRoleEq (tid, role) ->
      let tid' = substTID eqs tid
      in
        case M.lookup tid' roleeqs of
          -- A thread cannot execute two different roles.
          Just role' | role' /= role -> different "role" role role'
          _                          ->
            updateSolution (eqs { roleEqs = M.insert tid' role roleeqs })

    RoleEq (role1, role2)
      | role1 == role2 -> skipEq
      | otherwise      -> different "role" role1 role2

    AVarEq (av1, av2) ->
      let av1' = substAVar eqs av1
          av2' = substAVar eqs av2
          -- Replace @x@ by @y@ everywhere in the store before recording it.
          elimAVar x y = updateSolution (eqs {
              mvarEqs = M.map (substMsg        elimEqs) mveqs
            , agntEqs = M.map (substAgentEqRHS elimEqs) agnteqs
            , avarEqs = M.insert x y $ M.map (substAVar elimEqs) aveqs
            })
            where elimEqs = empty { avarEqs = M.singleton x y }
      in
        elimVarEqVar elimAVar (av1', av1') (av2', av2')

    AgentEq (lhs, rhs) ->
      let -- Replace agent identifier @x@ by @y@ everywhere in the store
          -- before recording it.
          elimAgentId x y = updateSolution (eqs {
              mvarEqs = M.map (substMsg elimEqs) mveqs
            , agntEqs = M.insert x y $ M.map (substAgentEqRHS elimEqs) agnteqs
            })
            where elimEqs = empty { agntEqs = M.singleton x y }
      in
        case (substAgentId eqs lhs, substAgentEqRHS eqs rhs) of
          (lhs'@(Left aid1), rhs'@(Left aid2)) ->
            elimVarEqVar elimAgentId (aid1, lhs') (aid2, rhs')
          (lhs'@(Right _  ),      (Left aid2)) -> elimAgentId aid2 lhs'
          (     (Left aid1), rhs'@(Right _  )) -> elimAgentId aid1 rhs'
          (     (Right av1),      (Right av2)) -> newEqs [AVarEq (av1, av2)]

    MVarEq (lhs, rhs) ->
      let elimMVar x y
            -- Occurs check: @x@ must not be replaced by a message that
            -- contains @x@ itself.
            | x `elem` msgFMV y =
                noUnifier $ "occurs check failed for '"++show x++"' in '"++show y++"'"
            | otherwise =
                updateSolution (eqs {
                  mvarEqs = M.insert x y $ M.map (substMsg elimEqs) mveqs
                })
            where elimEqs = empty { mvarEqs = M.singleton x y }
      in
        case (substMVar eqs lhs, substMsg eqs rhs) of
          (lhs'@(MMVar mv1), rhs'@(MMVar mv2)) ->
            elimVarEqVar elimMVar (mv1, lhs') (mv2, rhs')
          (lhs'            ,      (MMVar mv2)) -> elimMVar mv2 lhs'
          (     (MMVar mv1), rhs'            ) -> elimMVar mv1 rhs'
          (lhs'            , rhs'            ) -> newEqs [MsgEq (lhs', rhs')]

    MsgEq eq -> case eq of
      -- Variables are handled by the 'MVarEq' case.
      (MMVar mv1, rhs) -> newEqs [MVarEq (mv1, rhs)]
      (lhs, MMVar mv2) -> newEqs [MVarEq (mv2, lhs)]

      -- Key inversion: strip or move the inversion so that the remaining
      -- cases only see it on one side.
      (MInvKey x,  MInvKey y ) -> newEqs [MsgEq (x, y)]
      (MInvKey x,  MAsymPK m1) -> newEqs [MsgEq (x, MAsymSK m1)]
      (MAsymPK m1, MInvKey x ) -> newEqs [MsgEq (x, MAsymSK m1)]
      (MInvKey x,  MAsymSK m1) -> newEqs [MsgEq (x, MAsymPK m1)]
      (MAsymSK m1, MInvKey x ) -> newEqs [MsgEq (x, MAsymPK m1)]
      (m1,         MInvKey x ) -> newEqs [MsgEq (x, m1)]
      (MInvKey x,  m1        ) -> newEqs [MsgEq (x, m1)]

      -- Agents and agent variables.
      (MAgent aid1, MAgent aid2) -> newEqs [AgentEq (aid1, Left aid2)]
      (MAgent aid1, MAVar av2  ) -> newEqs [AgentEq (aid1, Right av2)]
      (MAVar av1,   MAgent aid2) -> newEqs [AgentEq (aid2, Right av1)]
      (MAVar av1,   MAVar av2  ) -> newEqs [AVarEq (av1, av2)]

      -- Fresh values with the same identifier are equal iff their threads
      -- are equal.
      (MFresh (Fresh fr1), MFresh (Fresh fr2))
        | lidId fr1 == lidId fr2 -> newEqs [TIDEq (lidTID fr1, lidTID fr2)]
        | otherwise              -> different "nonce" fr1 fr2

      -- Structural decomposition.
      (MHash m1,      MHash m2     ) -> newEqs [MsgEq (m1, m2)]
      (MTup m11 m12,  MTup m21 m22 ) -> newEqs [MsgEq (m11, m21), MsgEq (m12, m22)]
      (MEnc m11 m12,  MEnc m21 m22 ) -> newEqs [MsgEq (m11, m21), MsgEq (m12, m22)]
      (MAsymPK m1,    MAsymPK m2   ) -> newEqs [MsgEq (m1, m2)]
      (MAsymSK m1,    MAsymSK m2   ) -> newEqs [MsgEq (m1, m2)]
      (MSymK m11 m12, MSymK m21 m22) -> newEqs [MsgEq (m11, m21), MsgEq (m12, m22)]

      -- 'MShrK' arguments may match in either order. The equality is only
      -- decomposed when a pair of arguments is already syntactically equal;
      -- otherwise it is recorded in 'postEqs'.
      (m1@(MShrK m11 m12), m2@(MShrK m21 m22))
        | m11 == m21 -> newEqs [MsgEq (m12, m22)]
        | m11 == m22 -> newEqs [MsgEq (m12, m21)]
        | m12 == m21 -> newEqs [MsgEq (m11, m22)]
        | m12 == m22 -> newEqs [MsgEq (m11, m21)]
        | m11 == m12 -> newEqs [MsgEq (m11, m21), MsgEq (m11, m22)]
        | m21 == m22 -> newEqs [MsgEq (m11, m21), MsgEq (m12, m21)]
        | (m1, m2) `U.equiv` posteqs -> skipEq
        | otherwise ->
            return ([], eqs { postEqs = U.equate m1 m2 $ posteqs }, False)

      (MConst c1, MConst c2)
        | c1 == c2  -> skipEq
        | otherwise -> different "constant" c1 c2

      -- Any other combination of constructors cannot be unified.
      (m1, m2) -> different "message" m1 m2
  where
    -- The equality holds trivially: nothing to do.
    skipEq               = return ([],   eqs , False)
    -- Replace the equality by simpler ones; the store is unchanged.
    newEqs ueqs          = return (ueqs, eqs , False)
    -- The store was extended.
    updateSolution eqs'  = return ([],   eqs', True)
    noUnifier            = fail . ("solve1: " ++)
    different ty x y     = noUnifier $ ty ++ " '" ++ show x ++ "' /= '" ++ show y ++ "'"

    -- Orient an equality between two variables: the larger variable is
    -- eliminated in favour of the term of the smaller one.
    elimVarEqVar elim (vl, lhs) (vr, rhs) =
      case compare vl vr of
        EQ -> skipEq
        LT -> elim vr lhs
        GT -> elim vl rhs
-- | Remove all thread identifier equalities. Returns the thread identifiers
-- that were mapped (the keys of 'tidEqs') together with the trimmed store.
trimTIDEqs :: Equalities -> ([TID], Equalities)
trimTIDEqs eqs = (mapped, trimmed)
  where
    mapped  = M.keys (tidEqs eqs)
    trimmed = eqs { tidEqs = M.empty }
-- | Remove all agent equalities. Returns the agent identifiers that were
-- mapped (the keys of 'agntEqs') together with the trimmed store.
trimAgentEqs :: Equalities -> ([AgentId], Equalities)
trimAgentEqs eqs = (mapped, trimmed)
  where
    mapped  = M.keys (agntEqs eqs)
    trimmed = eqs { agntEqs = M.empty }
-- | The largest thread identifier occurring as a key of 'tidEqs', or
-- 'Nothing' if no thread identifier is mapped.
maxMappedTID :: Equalities -> Maybe TID
maxMappedTID eqs = case M.maxViewWithKey (tidEqs eqs) of
  Nothing       -> Nothing
  Just (kv, _)  -> Just (fst kv)
-- | The largest agent identifier occurring as a key of 'agntEqs', or
-- 'Nothing' if no agent identifier is mapped.
maxMappedAgentId :: Equalities -> Maybe AgentId
maxMappedAgentId eqs = case M.maxViewWithKey (agntEqs eqs) of
  Nothing       -> Nothing
  Just (kv, _)  -> Just (fst kv)
-- | The role stored for a thread, if any. The thread identifier is first
-- substituted with 'substTID' and then looked up in 'roleEqs'.
threadRole :: TID -> Equalities -> Maybe Role
threadRole tid eqs = M.lookup tid' (roleEqs eqs)
  where
    tid' = substTID eqs tid
-- | A newtype wrapper around 'Equalities'; the functions below build and
-- modify the wrapped store directly, without going through the solver.
newtype Mapping = Mapping
  { getMappingEqs :: Equalities  -- ^ the wrapped equalities
  }
  deriving (Eq, Ord, Show, Data, Typeable)
-- | Apply a function to the equalities wrapped in a 'Mapping'.
mapMapping :: (Equalities -> Equalities) -> Mapping -> Mapping
mapMapping f (Mapping eqs) = Mapping (f eqs)
-- | The mapping wrapping the 'empty' store.
emptyMapping :: Mapping
emptyMapping = Mapping { getMappingEqs = empty }
-- | Build a mapping from a thread identifier map and an agent identifier
-- map. Agent targets are stored as 'Left' values; all other components are
-- empty.
mkMapping :: M.Map TID TID -> M.Map AgentId AgentId -> Mapping
mkMapping tideqs agnteqs = Mapping store
  where
    store = empty
      { tidEqs  = tideqs
      , agntEqs = M.map Left agnteqs
      }
-- | Map thread identifier @from@ to @to@, replacing any existing entry for
-- @from@.
addTIDMapping :: TID -> TID -> Mapping -> Mapping
addTIDMapping from to = mapMapping insertTID
  where
    insertTID eqs = eqs { tidEqs = M.insert from to (tidEqs eqs) }
-- | Map agent identifier @from@ to agent identifier @to@ (stored as a
-- 'Left' value), replacing any existing entry for @from@.
addAgentIdMapping :: AgentId -> AgentId -> Mapping -> Mapping
addAgentIdMapping from to = mapMapping insertAgent
  where
    insertAgent eqs = eqs { agntEqs = M.insert from (Left to) (agntEqs eqs) }
-- | Assign a role to a thread. The thread identifier is substituted with
-- 'substTID' before the assignment is stored, so the entry is keyed by the
-- identifier that 'threadRole' will look up.
addTIDRoleMapping :: TID -> Role -> Mapping -> Mapping
addTIDRoleMapping tid role = mapMapping assignRole
  where
    assignRole eqs =
      eqs { roleEqs = M.insert (substTID eqs tid) role (roleEqs eqs) }
-- | Remove the entry for the given thread identifier from 'tidEqs', if
-- there is one.
deleteTIDMapping :: TID -> Mapping -> Mapping
deleteTIDMapping tid = mapMapping dropTID
  where
    dropTID eqs = eqs { tidEqs = M.delete tid (tidEqs eqs) }
-- | Remove the entry for the given agent identifier from 'agntEqs', if
-- there is one.
deleteAgentIdMapping :: AgentId -> Mapping -> Mapping
deleteAgentIdMapping aid = mapMapping dropAgent
  where
    dropAgent eqs = eqs { agntEqs = M.delete aid (agntEqs eqs) }
-- | Pretty-print an equality: the left side, an equals sign and the right
-- side, joined with '<->'. Each side has its own printer.
ppEq :: (a -> Doc) -> (b -> Doc) -> (a, b) -> Doc
ppEq pp1 pp2 (x1, x2) = lhs <-> char '=' <-> rhs
  where
    lhs = pp1 x1
    rhs = pp2 x2
-- | Variant of 'ppEq' for equalities whose sides share one printer.
ppEq' :: (a -> Doc) -> (a, a) -> Doc
ppEq' pp eq = ppEq pp pp eq
-- | Isar output of equalities. Role equalities print the role names; a role
-- assignment is printed as a @roleMap r@ lookup yielding @Some@ role name;
-- all other sides are printed with their own 'Isar' instances.
instance Isar AnyEq where
  isar conf = render
    where
      -- 'isar' with the configuration already supplied.
      pp :: Isar a => a -> Doc
      pp = isar conf

      render (TIDEq eq)  = ppEq' pp eq
      render (RoleEq eq) = ppEq' (text . roleName) eq
      render (TIDRoleEq (tid, role)) =
        text "roleMap r" <-> pp tid <-> text ("= Some " ++ roleName role)
      render (AgentEq eq) = ppEq pp (either pp pp) eq
      render (AVarEq eq)  = ppEq' pp eq
      render (MVarEq eq)  = ppEq pp pp eq
      render (MsgEq eq)   = ppEq' pp eq
-- | Pretty-print an equality using the @spt*@ printers. Role equalities
-- print the role names; a role assignment is printed in the form
-- @role( tid ) = name@, with the pieces joined by '<->'.
sptAnyEq :: AnyEq -> Doc
sptAnyEq (TIDEq eq)  = ppEq' sptTID eq
sptAnyEq (RoleEq eq) = ppEq' (text . roleName) eq
sptAnyEq (TIDRoleEq (tid, role)) =
  text "role(" <-> sptTID tid <-> text (") = " ++ roleName role)
sptAnyEq (AgentEq eq) = ppEq sptAgentId (either sptAgentId sptAVar) eq
sptAnyEq (AVarEq eq)  = ppEq' sptAVar eq
sptAnyEq (MVarEq eq)  = ppEq sptMVar sptMessage eq
sptAnyEq (MsgEq eq)   = ppEq' sptMessage eq