{-# LANGUAGE TemplateHaskell #-}
--------------------------------------------------------
-- |
-- Module      : Crypto.Noise.Internal.Handshake.Pattern
-- Maintainer  : John Galt <jgalt@centromere.net>
-- Stability   : experimental
-- Portability : POSIX
module Crypto.Noise.Internal.Handshake.Pattern where

import Control.Applicative.Free
import Control.Lens
import Data.ByteString (ByteString)
import Data.Semigroup (Semigroup(..))

-- | A single handshake token. Every constructor carries exactly one field of
--   type @next@ and no other payload; the smart constructors in this module
--   ('e', 's', 'ee', 'es', 'se', 'ss', 'psk') always fill that field with @()@.
--
--   NOTE(review): the constructor names presumably mirror the Noise protocol
--   tokens of the same names -- confirm against the specification.
data Token next
  = E   next
  | S   next
  | Ee  next
  | Es  next
  | Se  next
  | Ss  next
  | Psk next

-- | A pattern of 'Token's, represented as the free applicative ('Ap') over
--   'Token'. Patterns are built from the token combinators in this module.
type MessagePattern = Ap Token

-- | A 'MessagePattern' consisting of the single token 'E', lifted into the
--   free applicative with 'liftAp'.
e :: MessagePattern ()
e = liftAp $ E ()

-- | A 'MessagePattern' consisting of the single token 'S', lifted into the
--   free applicative with 'liftAp'.
s :: MessagePattern ()
s = liftAp $ S ()

-- | A 'MessagePattern' consisting of the single token 'Ee', lifted into the
--   free applicative with 'liftAp'.
ee :: MessagePattern ()
ee = liftAp $ Ee ()

-- | A 'MessagePattern' consisting of the single token 'Es', lifted into the
--   free applicative with 'liftAp'.
es :: MessagePattern ()
es = liftAp $ Es ()

-- | A 'MessagePattern' consisting of the single token 'Se', lifted into the
--   free applicative with 'liftAp'.
se :: MessagePattern ()
se = liftAp $ Se ()

-- | A 'MessagePattern' consisting of the single token 'Ss', lifted into the
--   free applicative with 'liftAp'.
ss :: MessagePattern ()
ss = liftAp $ Ss ()

-- | A 'MessagePattern' consisting of the single token 'Psk', lifted into the
--   free applicative with 'liftAp'. 'handshakePattern' looks for this token
--   to decide whether a pattern is in PSK mode.
psk :: MessagePattern ()
psk = liftAp $ Psk ()

-- | One entry of a handshake: a 'MessagePattern' tagged with one of four
--   roles, plus a @next@ field that the smart constructors in this module
--   ('preInitiator', 'preResponder', 'initiator', 'responder') always fill
--   with @()@.
--
--   'handshakePattern' only inspects the patterns of 'Initiator' and
--   'Responder' entries when detecting PSK usage; 'PreInitiator' and
--   'PreResponder' entries are ignored there.
--
--   NOTE(review): the @Pre*@ constructors presumably model Noise
--   pre-messages -- confirm against the specification.
data Message next
  = PreInitiator (MessagePattern ()) next
  | PreResponder (MessagePattern ()) next
  | Initiator    (MessagePattern ()) next
  | Responder    (MessagePattern ()) next

-- | A sequence of 'Message's, represented as the free applicative ('Ap') over
--   'Message'.
type MessageSequence = Ap Message

-- | Wraps a 'MessagePattern' in a 'PreInitiator' entry and lifts it into a
--   one-element 'MessageSequence'.
preInitiator :: MessagePattern () -> MessageSequence ()
preInitiator mp = liftAp (PreInitiator mp ())

-- | Wraps a 'MessagePattern' in a 'PreResponder' entry and lifts it into a
--   one-element 'MessageSequence'.
preResponder :: MessagePattern () -> MessageSequence ()
preResponder mp = liftAp (PreResponder mp ())

-- | Wraps a 'MessagePattern' in an 'Initiator' entry and lifts it into a
--   one-element 'MessageSequence'.
initiator :: MessagePattern () -> MessageSequence ()
initiator mp = liftAp (Initiator mp ())

-- | Wraps a 'MessagePattern' in a 'Responder' entry and lifts it into a
--   one-element 'MessageSequence'.
responder :: MessagePattern () -> MessageSequence ()
responder mp = liftAp (Responder mp ())

-- | This type represents a handshake pattern such as @Noise_IK@. A large set of
--   pre-defined patterns can be found in "Crypto.Noise.HandshakePatterns".
--   Expert users are encouraged to define their own custom patterns with care.
--
--   Values are normally built with 'handshakePattern', which fills in
--   '_hpPSKMode' from the message sequence.
data HandshakePattern = HandshakePattern
  { _hpName    :: ByteString
    -- ^ Protocol name, stored exactly as passed to 'handshakePattern'.
  , _hpPSKMode :: Bool
    -- ^ 'True' when 'handshakePattern' found a 'Psk' token in any
    --   'Initiator' or 'Responder' message of the sequence.
  , _hpMsgSeq  :: MessageSequence ()
    -- ^ The message sequence that makes up the pattern.
  }

-- Template Haskell splice: derives lenses for the underscore-prefixed fields
-- of 'HandshakePattern'. Per the lens library's naming rule for 'makeLenses',
-- these should be @hpName@, @hpPSKMode@ and @hpMsgSeq@.
$(makeLenses ''HandshakePattern)

-- | Boolean flag recording whether a 'Psk' token has been seen. It is a
--   separate type so that it can carry the 'Semigroup' / 'Monoid' instances
--   defined in this module, which 'handshakePattern' uses with 'runAp_'.
newtype HasPSK = HasPSK { unPSK :: Bool }

-- | Constructs a 'HandshakePattern' given a protocol name (such as @XXpsk3@)
--   and raw pattern. Please see the README for information about creating your
--   own custom patterns.
--
--   The PSK-mode flag of the result is computed here: it is 'True' exactly
--   when at least one 'Initiator' or 'Responder' message in the sequence
--   contains a 'Psk' token. 'PreInitiator' and 'PreResponder' messages are
--   not inspected.
handshakePattern :: ByteString
                 -> MessageSequence ()
                 -> HandshakePattern
handshakePattern protoName ms = HandshakePattern protoName hasPSK ms
  where
    -- Fold the whole sequence into a single flag via the 'HasPSK' monoid.
    hasPSK :: Bool
    hasPSK = unPSK $ runAp_ scanS ms

    -- Per-message contribution: only initiator/responder messages can
    -- contribute a set flag.
    scanS :: Message next -> HasPSK
    scanS (PreInitiator _ _) = mempty
    scanS (PreResponder _ _) = mempty
    scanS (Initiator   mp _) = runAp_ scanP mp
    scanS (Responder   mp _) = runAp_ scanP mp

    -- Per-token contribution: only 'Psk' sets the flag.
    scanP :: Token next -> HasPSK
    scanP (Psk _) = HasPSK True
    scanP _       = mempty

-- | Combining two flags yields a flag that is set when either operand is set
--   (logical OR of the wrapped 'Bool's).
instance Semigroup HasPSK where
  (HasPSK a) <> (HasPSK b) = HasPSK $ a || b

-- | The identity element is the unset flag (@HasPSK False@), which is neutral
--   for the OR-based '<>'. 'mappend' simply delegates to '<>'.
instance Monoid HasPSK where
  mempty  = HasPSK False
  mappend = (<>)