module Scyther.Message (
TID(..)
, LocalId(..)
, Fresh(..)
, AVar(..)
, MVar(..)
, AgentId(..)
, Message(..)
, lidId
, lidTID
, avarTID
, mvarTID
, msgFMV
, msgFresh
, msgAgentIds
, msgTIDs
, trivial
, submessages
, messageparts
, mapFresh
, mapAVar
, mapMVar
, inst
, normMsg
, splitNonTrivial
, sptTID
, sptAgentId
, sptFresh
, sptAVar
, sptMVar
, sptMessage
) where
import Control.Monad
import Control.Applicative
import Data.Data
import Data.Monoid
import qualified Data.Set as S
import Text.Isar
import Scyther.Protocol
-- | An 'Int'-backed identifier.
-- NOTE(review): presumably a thread identifier (the Isabelle rendering uses
-- the prefix @tid@) -- confirm against the callers.
newtype TID = TID { getTID :: Int }
  deriving( Eq, Ord, Enum, Num, Data, Typeable )

-- | A 'TID' is shown as a hash sign followed by its number, e.g. @#3@.
instance Show TID where
  show = ('#' :) . show . getTID
-- | An 'Int'-backed identifier of an agent.
newtype AgentId = AgentId { agentId :: Int }
  deriving( Eq, Ord, Enum, Num, Data, Typeable )

-- | An 'AgentId' is shown as the letter @a@ followed by its number, e.g. @a2@.
instance Show AgentId where
  show = ('a' :) . show . agentId
-- | An 'Id' paired with the 'TID' it is local to.
newtype LocalId = LocalId { getLocalId :: (Id, TID) }
  deriving( Eq, Ord, Data, Typeable )

-- | Shows the 'Id' component immediately followed by the 'TID' component.
instance Show LocalId where
  show lid = show (lidId lid) ++ show (lidTID lid)
-- | A 'LocalId' tagged as an agent variable (constructor 'MAVar' of 'Message'
-- carries it).
newtype AVar = AVar { getAVar :: LocalId }
  deriving( Eq, Ord, Show, Data, Typeable )
-- | A 'LocalId' tagged as a message variable (constructor 'MMVar' of
-- 'Message' carries it).
newtype MVar = MVar { getMVar :: LocalId }
  deriving( Eq, Ord, Show, Data, Typeable )
-- | A 'LocalId' tagged as a fresh value (constructor 'MFresh' of 'Message'
-- carries it).
-- NOTE(review): presumably a freshly generated nonce-like value -- confirm.
newtype Fresh = Fresh { getFresh :: LocalId }
  deriving( Eq, Ord, Show, Data, Typeable )
-- | Message terms.
--
-- The derived 'Ord' instance is used by 'normMsg' to order the two arguments
-- of 'MShrK'.
data Message =
    MConst Id                -- ^ A constant named by an 'Id'.
  | MFresh Fresh             -- ^ A fresh value.
  | MAVar AVar               -- ^ An agent variable.
  | MMVar MVar               -- ^ A message variable.
  | MAgent AgentId           -- ^ A concrete agent.
  | MHash Message            -- ^ Hash of a message.
  | MTup Message Message     -- ^ Pair of messages; nested to the right for longer tuples.
  | MEnc Message Message     -- ^ Encryption; presumably content first, key second -- confirm.
  | MSymK Message Message    -- ^ Symmetric key built from two messages (rendered @k(a,b)@).
  | MShrK Message Message    -- ^ Shared key whose argument order is normalized by 'normMsg' (rendered @k[a,b]@).
  | MAsymPK Message          -- ^ Public key built from a message (rendered @pk(a)@).
  | MAsymSK Message          -- ^ Secret key built from a message (rendered @sk(a)@).
  | MInvKey Message          -- ^ Key inversion; mostly eliminated by 'normMsg'.
  deriving( Eq, Ord, Show, Data, Typeable )
-- | The 'Id' component of a 'LocalId'.
lidId :: LocalId -> Id
lidId (LocalId (name, _)) = name
-- | The 'TID' component of a 'LocalId'.
lidTID :: LocalId -> TID
lidTID (LocalId (_, tid)) = tid
-- | The 'TID' of the 'LocalId' wrapped by an agent variable.
avarTID :: AVar -> TID
avarTID (AVar (LocalId (_, tid))) = tid
-- | The 'TID' of the 'LocalId' wrapped by a message variable.
mvarTID :: MVar -> TID
mvarTID (MVar (LocalId (_, tid))) = tid
-- | The 'TID's of all fresh values, agent variables and message variables
-- occurring in a message, in left-to-right order and with duplicates kept.
-- Constants and agents contribute nothing.
msgTIDs :: Message -> [TID]
msgTIDs msg = case msg of
    MConst _   -> []
    MFresh f   -> [lidTID (getFresh f)]
    MAVar v    -> [avarTID v]
    MMVar v    -> [mvarTID v]
    MAgent _   -> []
    MHash m    -> msgTIDs m
    MTup a b   -> ofBoth a b
    MEnc a b   -> ofBoth a b
    MSymK a b  -> ofBoth a b
    MShrK a b  -> ofBoth a b
    MAsymPK m  -> msgTIDs m
    MAsymSK m  -> msgTIDs m
    MInvKey m  -> msgTIDs m
  where
    -- TIDs of the left component followed by those of the right one.
    ofBoth a b = msgTIDs a ++ msgTIDs b
-- | The 'AgentId's of all 'MAgent' leaves of a message, in left-to-right
-- order and with duplicates kept. All other leaves contribute nothing.
msgAgentIds :: Message -> [AgentId]
msgAgentIds msg = case msg of
    MConst _   -> []
    MFresh _   -> []
    MAVar _    -> []
    MMVar _    -> []
    MAgent a   -> [a]
    MHash m    -> msgAgentIds m
    MTup a b   -> ofBoth a b
    MEnc a b   -> ofBoth a b
    MSymK a b  -> ofBoth a b
    MShrK a b  -> ofBoth a b
    MAsymPK m  -> msgAgentIds m
    MAsymSK m  -> msgAgentIds m
    MInvKey m  -> msgAgentIds m
  where
    -- Agents of the left component followed by those of the right one.
    ofBoth a b = msgAgentIds a ++ msgAgentIds b
-- | The message variables occurring in a message, in left-to-right order and
-- with duplicates kept.
--
-- Every compound constructor is traversed, including 'MShrK', so that the
-- result agrees with 'msgTIDs' and 'msgAgentIds' on which subterms are
-- visited. The leaf cases are spelled out (rather than matched by a wildcard)
-- so that a newly added constructor triggers an incomplete-pattern warning
-- instead of silently yielding no variables.
msgFMV :: Message -> [MVar]
msgFMV (MMVar v)     = pure v
msgFMV (MHash m)     = msgFMV m
msgFMV (MTup m1 m2)  = msgFMV m1 <|> msgFMV m2
msgFMV (MEnc m1 m2)  = msgFMV m1 <|> msgFMV m2
msgFMV (MSymK m1 m2) = msgFMV m1 <|> msgFMV m2
msgFMV (MShrK m1 m2) = msgFMV m1 <|> msgFMV m2
msgFMV (MAsymPK m)   = msgFMV m
msgFMV (MAsymSK m)   = msgFMV m
msgFMV (MInvKey m)   = msgFMV m
msgFMV (MConst _)    = empty
msgFMV (MFresh _)    = empty
msgFMV (MAVar _)     = empty
msgFMV (MAgent _)    = empty
-- | The fresh values occurring in a message, in left-to-right order and with
-- duplicates kept.
--
-- Every compound constructor is traversed, including 'MShrK', so that the
-- result agrees with 'msgTIDs' and 'msgAgentIds' on which subterms are
-- visited. The leaf cases are spelled out (rather than matched by a wildcard)
-- so that a newly added constructor triggers an incomplete-pattern warning
-- instead of silently yielding no fresh values.
msgFresh :: Message -> [Fresh]
msgFresh (MFresh lid)  = pure lid
msgFresh (MHash m)     = msgFresh m
msgFresh (MTup m1 m2)  = msgFresh m1 <|> msgFresh m2
msgFresh (MEnc m1 m2)  = msgFresh m1 <|> msgFresh m2
msgFresh (MSymK m1 m2) = msgFresh m1 <|> msgFresh m2
msgFresh (MShrK m1 m2) = msgFresh m1 <|> msgFresh m2
msgFresh (MAsymPK m)   = msgFresh m
msgFresh (MAsymSK m)   = msgFresh m
msgFresh (MInvKey m)   = msgFresh m
msgFresh (MConst _)    = empty
msgFresh (MAVar _)     = empty
msgFresh (MMVar _)     = empty
msgFresh (MAgent _)    = empty
-- | 'True' exactly for constants, agent variables, tuples and agents;
-- 'False' for every other message.
trivial :: Message -> Bool
trivial msg = case msg of
    MConst{} -> True
    MAVar{}  -> True
    MTup{}   -> True
    MAgent{} -> True
    _        -> False
-- | The set containing a message and, recursively, the submessages of its
-- components for hashes, tuples, encryptions, symmetric keys and asymmetric
-- public/secret keys. Any other message yields just the singleton set.
--
-- Calls 'error' when it meets an 'MInvKey' at the root or in a traversed
-- position.
--
-- NOTE(review): 'MShrK' is treated as a leaf here (its arguments are not
-- traversed), unlike 'MSymK' -- confirm that this is intended.
submessages :: Message -> S.Set Message
submessages msg = case msg of
    MHash x    -> withSelf (submessages x)
    MTup x y   -> withSelf (submessages x `S.union` submessages y)
    MEnc x y   -> withSelf (submessages x `S.union` submessages y)
    MSymK x y  -> withSelf (submessages x `S.union` submessages y)
    MAsymPK x  -> withSelf (submessages x)
    MAsymSK x  -> withSelf (submessages x)
    MInvKey _  -> error "submessages: undefined for key inversion"
    _          -> S.singleton msg
  where
    -- Add the message itself to the submessages of its components.
    withSelf = S.insert msg
-- | The set containing a message and, recursively, the parts of both
-- components of every tuple and encryption. All other messages are kept
-- whole as singleton sets.
messageparts :: Message -> S.Set Message
messageparts msg = case msg of
    MTup l r -> S.insert msg (messageparts l `S.union` messageparts r)
    MEnc l r -> S.insert msg (messageparts l `S.union` messageparts r)
    _        -> S.singleton msg
-- | Apply a function to the 'LocalId' wrapped by a 'Fresh' value.
mapFresh :: (LocalId -> LocalId) -> Fresh -> Fresh
mapFresh f (Fresh lid) = Fresh (f lid)
-- | Apply a function to the 'LocalId' wrapped by an agent variable.
mapAVar :: (LocalId -> LocalId) -> AVar -> AVar
mapAVar f (AVar lid) = AVar (f lid)
-- | Apply a function to the 'LocalId' wrapped by a message variable.
mapMVar :: (LocalId -> LocalId) -> MVar -> MVar
mapMVar f (MVar lid) = MVar (f lid)
-- | Turn a 'Pattern' into a 'Message' by pairing every fresh value, agent
-- variable and message variable identifier with the given 'TID'.
--
-- A signature pattern becomes the tuple of the signed message and that
-- message encrypted under the normalized inverse of the instantiated key.
inst :: TID -> Pattern -> Message
inst tid = go
  where
    -- Identifier made local to the thread being instantiated.
    local i = LocalId (i, tid)

    go pat = case pat of
      PConst i      -> MConst i
      PFresh i      -> MFresh (Fresh (local i))
      PAVar i       -> MAVar (AVar (local i))
      PMVar i       -> MMVar (MVar (local i))
      PHash p       -> MHash (go p)
      PTup p q      -> MTup (go p) (go q)
      PEnc p q      -> MEnc (go p) (go q)
      PSign p key   ->
        let signed = go p
        in  MTup signed (MEnc signed (normMsg (MInvKey (go key))))
      PSymK p q     -> MSymK (go p) (go q)
      PShrK p q     -> MShrK (go p) (go q)
      PAsymPK p     -> MAsymPK (go p)
      PAsymSK p     -> MAsymSK (go p)
-- | Normalize a message.
--
-- * Key inversion is pushed away: a double inversion cancels, inverting a
--   public key gives the secret key and vice versa, the inversion of a
--   message variable is kept as is, and the inversion of anything else is
--   dropped.
--
-- * The two arguments of 'MShrK' are normalized and then ordered so that the
--   smaller one (w.r.t. the derived 'Ord' instance) comes first.
--
-- * All other constructors are normalized component-wise; leaves are
--   returned unchanged.
normMsg :: Message -> Message
normMsg msg = case msg of
    MConst _   -> msg
    MFresh _   -> msg
    MAVar _    -> msg
    MMVar _    -> msg
    MAgent _   -> msg
    MInvKey k  -> invert k
    MHash m    -> MHash (normMsg m)
    MTup a b   -> MTup (normMsg a) (normMsg b)
    MEnc a b   -> MEnc (normMsg a) (normMsg b)
    MSymK a b  -> MSymK (normMsg a) (normMsg b)
    MShrK a b  ->
      let a' = normMsg a
          b' = normMsg b
      in  if a' < b' then MShrK a' b' else MShrK b' a'
    MAsymPK m  -> MAsymPK (normMsg m)
    MAsymSK m  -> MAsymSK (normMsg m)
  where
    -- Normal form of the inverse of the given (not yet normalized) key.
    invert key = case key of
      MInvKey m  -> normMsg m
      MAsymPK m  -> MAsymSK (normMsg m)
      MAsymSK m  -> MAsymPK (normMsg m)
      MMVar _    -> MInvKey key
      _          -> normMsg key
-- | Flatten nested tuples and keep, in left-to-right order, only those
-- components for which 'trivial' is 'False'.
splitNonTrivial :: Message -> [Message]
splitNonTrivial (MTup l r) = splitNonTrivial l ++ splitNonTrivial r
splitNonTrivial m
  | trivial m = []
  | otherwise = [m]
-- | Wrap a variable document in parentheses and put a substitution name in
-- front of it: the one given by 'isarSubst' when the identifier's 'TID' is
-- zero, and the literal @s@ otherwise.
esplSubst :: LocalId -> IsarConf -> Doc -> Doc
esplSubst lid conf var = subst <> parens var
  where
    subst
      | lidTID lid == 0 = isarSubst conf
      | otherwise       = text "s"
-- | A 'TID' is rendered as @tid@ directly followed by its number.
instance Isar TID where
  isar _ (TID n) = text "tid" <> int n
-- | An 'AgentId' is rendered as @a@ directly followed by its number.
instance Isar AgentId where
  isar _ (AgentId n) = text "a" <> int n
-- | A 'LocalId' is rendered as its 'Id' joined with its 'TID' using '<->'.
instance Isar LocalId where
  isar conf lid = isar conf (lidId lid) <-> isar conf (lidTID lid)
-- | A 'Fresh' value is rendered as @LN@ joined with its 'LocalId' using '<->'.
instance Isar Fresh where
  isar conf = (text "LN" <->) . isar conf . getFresh
-- | An agent variable is rendered as @AV@ joined with its 'LocalId' and then
-- passed through 'esplSubst'.
instance Isar AVar where
  isar conf av = esplSubst lid conf body
    where
      lid  = getAVar av
      body = text "AV" <-> isar conf lid
-- | A message variable is rendered as @MV@ joined with its 'LocalId' and then
-- passed through 'esplSubst'.
instance Isar MVar where
  isar conf mv = esplSubst lid conf body
    where
      lid  = getMVar mv
      body = text "MV" <-> isar conf lid
-- | Isabelle rendering of messages.
--
-- Leaves delegate to the instances of their payload (constants get the
-- prefix @LC@). Compound messages are a constructor name followed by their
-- arguments; each argument goes through 'wrapped', which renders tuples
-- between tuple delimiters and everything else via @nestShort'@ with round
-- parentheses. Key inversion is rendered as @inv(...)@.
instance Isar Message where
  isar conf msg = case msg of
      MConst c     -> text "LC" <-> isar conf c
      MFresh f     -> isar conf f
      MAVar v      -> isar conf v
      MMVar v      -> isar conf v
      MAgent a     -> isar conf a
      MHash body   -> text "Hash" <-> wrapped body
      MTup _ _     -> wrapped msg
      MEnc body k  -> text "Enc" <-> sep [wrapped body, wrapped k]
      MSymK a b    -> text "K" <-> sep [wrapped a, wrapped b]
      MShrK a b    -> text "Kbd" <-> sep [wrapped a, wrapped b]
      MAsymPK a    -> text "PK" <-> wrapped a
      MAsymSK a    -> text "SK" <-> wrapped a
      MInvKey k    -> text "inv" <> parens (isar conf k)
    where
      -- Tuples: comma-separated components between the tuple delimiters.
      -- Anything else: the message itself between round parentheses.
      wrapped t@(MTup _ _) =
        nestShort indent open close
          (fsep (punctuate comma (map (isar conf) (components t))))
      wrapped t = nestShort' "(" ")" (isar conf t)

      -- Components of a right-nested tuple, in order.
      components (MTup l r) = l : components r
      components t          = [t]

      -- Nesting depth and tuple delimiters for the selected output style.
      (indent, open, close) =
        if isPlainStyle conf
          then (3, text "{|", text "|}")
          else (2, symbol "\\<lbrace>", symbol "\\<rbrace>")
-- | Render a 'TID' using its 'Show' instance.
sptTID :: TID -> Doc
sptTID tid = text (show tid)
-- | Render an 'AgentId' as @a@ directly followed by its number.
sptAgentId :: AgentId -> Doc
sptAgentId (AgentId n) = char 'a' <> int n
-- | Render a 'LocalId' as its 'Id' directly followed by its 'TID'.
sptLocalId :: LocalId -> Doc
sptLocalId lid = sptId (lidId lid) <> sptTID (lidTID lid)
-- | Render a 'Fresh' value as @~@ directly followed by its 'LocalId'.
sptFresh :: Fresh -> Doc
sptFresh (Fresh lid) = char '~' <> sptLocalId lid
-- | Render an agent variable as its 'LocalId', without any prefix.
sptAVar :: AVar -> Doc
sptAVar (AVar lid) = sptLocalId lid
-- | Render a message variable as @?@ directly followed by its 'LocalId'.
sptMVar :: MVar -> Doc
sptMVar (MVar lid) = char '?' <> sptLocalId lid
-- | Render a message in compact notation: constants in single quotes,
-- hashes as @h(..)@, tuples in round parentheses, encryptions as the content
-- in braces directly followed by the key, symmetric keys as @k(a,b)@, shared
-- keys as @k[a,b]@, asymmetric keys as @pk(..)@ / @sk(..)@ and key inversion
-- as @inv(..)@.
sptMessage :: Message -> Doc
sptMessage msg = case msg of
    MConst c     -> char '\'' <> sptId c <> char '\''
    MFresh f     -> sptFresh f
    MAVar v      -> sptAVar v
    MAgent a     -> sptAgentId a
    MMVar v      -> sptMVar v
    MHash body   -> text "h" <> enclosed 1 "(" ")" body
    MTup _ _     -> enclosed 1 "(" ")" msg
    MEnc body k  -> fcat [enclosed 1 "{" "}" body, sptMessage k]
    MSymK a b    -> fcat [text "k(", sptMessage a, comma, sptMessage b, text ")"]
    MShrK a b    -> fcat [text "k[", sptMessage a, comma, sptMessage b, text "]"]
    MAsymPK a    -> text "pk" <> enclosed 1 "(" ")" a
    MAsymSK a    -> text "sk" <> enclosed 1 "(" ")" a
    MInvKey k    -> text "inv" <> enclosed 1 "(" ")" k
  where
    -- Put a message between an opening and a closing string. The components
    -- of a tuple are listed comma-separated and nested by the given depth;
    -- any other message is placed between the delimiters as is.
    enclosed depth open close t@(MTup _ _) =
        fcat ([text open] ++ items ++ [text close])
      where
        items = map (nest depth)
                    (punctuate (text ", ") (map sptMessage (components t)))
    enclosed _ open close t = text open <> sptMessage t <> text close

    -- Components of a right-nested tuple, in order.
    components (MTup l r) = l : components r
    components t          = [t]