{-# LANGUAGE CPP, DisambiguateRecordFields, RecordWildCards, NamedFieldPuns #-}

module Nettle.OpenFlow.Match ( 
  Match (..)
  , matchAny
  , isExactMatch
  , getExactMatch
  , frameToExactMatch
  , ofpVlanNone
  , matches
  ) where

import Nettle.Ethernet.EthernetAddress
import Nettle.Ethernet.EthernetFrame 
import Nettle.IPv4.IPAddress
import qualified Nettle.IPv4.IPPacket as IP
import Nettle.OpenFlow.Port
import Data.Maybe (isJust)
import Data.Binary
import Control.Monad.Error

-- | The packet-matching condition carried by every flow entry.
-- A 'Maybe' field set to 'Nothing' is a wildcard.  The two IP address
-- fields are not optional; they are prefixes that the packet address must
-- fall within (see 'matches').
data Match = Match
  { inPort           :: Maybe PortID              -- ^ port the frame arrived on
  , srcEthAddress    :: Maybe EthernetAddress     -- ^ Ethernet source address
  , dstEthAddress    :: Maybe EthernetAddress     -- ^ Ethernet destination address
  , vLANID           :: Maybe VLANID              -- ^ 802.1Q VLAN ID; 'ofpVlanNone' stands for an untagged frame
#if OPENFLOW_VERSION==152 || OPENFLOW_VERSION==1
  , vLANPriority     :: Maybe VLANPriority        -- ^ 802.1Q priority code point
#endif
  , ethFrameType     :: Maybe EthernetTypeCode    -- ^ Ethernet type code
#if OPENFLOW_VERSION==1
  , ipTypeOfService  :: Maybe IP.IPTypeOfService  -- ^ IP type-of-service value
#endif
  , ipProtocol       :: Maybe IP.IPProtocol       -- ^ IP protocol (for ARP frames 'frameToExactMatch' stores the opcode here)
  , srcIPAddress     :: IPAddressPrefix           -- ^ prefix for the source IP address
  , dstIPAddress     :: IPAddressPrefix           -- ^ prefix for the destination IP address
  , srcTransportPort :: Maybe IP.TransportPort    -- ^ TCP or UDP source port (ICMP type for ICMP in 'frameToExactMatch')
  , dstTransportPort :: Maybe IP.TransportPort    -- ^ TCP or UDP destination port
  }
  deriving (Show, Read, Eq)


-- | The wildcard match, intended to match every packet: every 'Maybe'
-- field is 'Nothing', and both IP address fields are 'defaultIPPrefix'.
matchAny :: Match
matchAny =
  Match { srcTransportPort = Nothing
        , dstTransportPort = Nothing
        , srcIPAddress     = defaultIPPrefix
        , dstIPAddress     = defaultIPPrefix
        , ipProtocol       = Nothing
#if OPENFLOW_VERSION==1
        , ipTypeOfService  = Nothing
#endif
        , ethFrameType     = Nothing
#if OPENFLOW_VERSION==152 || OPENFLOW_VERSION==1
        , vLANPriority     = Nothing
#endif
        , vLANID           = Nothing
        , srcEthAddress    = Nothing
        , dstEthAddress    = Nothing
        , inPort           = Nothing
        }

-- | True when the 'Match' has no wildcards: every 'Maybe' field is a
--   'Just', and both IP address prefixes satisfy 'prefixIsExact'.
isExactMatch :: Match -> Bool
isExactMatch (Match {..}) =
  and [ isJust inPort
      , isJust srcEthAddress
      , isJust dstEthAddress
      , isJust vLANID
#if OPENFLOW_VERSION==152 || OPENFLOW_VERSION==1
      , isJust vLANPriority
#endif
      , isJust ethFrameType
#if OPENFLOW_VERSION==1
      , isJust ipTypeOfService
#endif
      , isJust ipProtocol
      , prefixIsExact srcIPAddress
      , prefixIsExact dstIPAddress
      , isJust srcTransportPort
      , isJust dstTransportPort
      ]

-- | VLAN ID value that stands for an untagged frame; 'frameToExactMatch'
-- stores it in the 'vLANID' field of frames without an 802.1Q tag.
-- It is all ones in 16 bits (0xffff).  The explicit 'Integer' signature
-- pins the type that defaulting previously chose implicitly; callers
-- convert it with 'fromIntegral'.
ofpVlanNone :: Integer
ofpVlanNone = 0xffff

-- | Build the 'Match' describing a frame received on the given port.
--
-- It starts from 'matchAny' with 'inPort' set, then fills in fields for
-- each header the frame carries:
--
--   * Ethernet source, destination and type code are always set.  A tagged
--     frame contributes its VLAN ID (and priority code point where the
--     'vLANPriority' field exists).  An untagged frame gets 'ofpVlanNone'
--     as VLAN ID and priority 0.
--
--   * IP packets set the protocol and source and destination prefixes built
--     with @// 32@ (and, under OpenFlow 1.0, the ToS field from the header
--     dscp).
--
--   * TCP and UDP set both transport ports.  ICMP stores the ICMP type in
--     the source port and 0 in the destination port.
--
--   * ARP packets reuse the IP fields: 'ipProtocol' is 1 for a request and
--     2 for any other opcode, and the sender and target IP addresses become
--     the source and destination prefixes.
--
-- Fields for headers the frame lacks keep their 'matchAny' value.  For
-- example, ARP frames get no transport ports.  So the result does not
-- always satisfy 'isExactMatch'.
frameToExactMatch :: PortID -> EthernetFrame -> Match
frameToExactMatch inPort frame = 
  addEthConditions frame (matchAny { inPort = Just inPort })
  where addEthConditions (EthernetFrame ethHdr ethBody) m = 
          -- Layer 2: Ethernet addresses, type code and VLAN tag.
          let m1 = case ethHdr of 
                Ethernet8021Q {..} -> 
                  m { srcEthAddress = Just sourceMACAddress
                    , dstEthAddress = Just destMACAddress
                    , vLANID        = Just vlanId
                    , ethFrameType  = Just typeCode
#if OPENFLOW_VERSION==152 || OPENFLOW_VERSION==1                                      
                    , vLANPriority  = Just priorityCodePoint
#endif
                    }
                -- Untagged frame: the missing tag is recorded as ofpVlanNone.
                EthernetHeader {..} -> 
                  m { srcEthAddress = Just sourceMACAddress
                    , dstEthAddress = Just destMACAddress
                    , ethFrameType  = Just typeCode
                    , vLANID        = Just (fromIntegral ofpVlanNone)
#if OPENFLOW_VERSION==152 || OPENFLOW_VERSION==1  
                    , vLANPriority  = Just 0
#endif
                    }
          -- Layer 3 and above: only IP and ARP payloads add conditions.
          in case ethBody of 
            -- IPHeader {..} binds a local ipProtocol that shadows the Match
            -- field selector; the record update below uses the local value.
            IPInEthernet (IP.IPPacket (IP.IPHeader {..}) ipBody) -> 
              let m2 = m1 { ipProtocol       = Just ipProtocol
                          , srcIPAddress     = ipSrcAddress // 32 
                          , dstIPAddress     = ipDstAddress // 32 
#if OPENFLOW_VERSION==1
                          , ipTypeOfService  = Just dscp 
#endif
                          }
              in case ipBody of 
                    IP.TCPInIP (src,dst)            -> m2 { srcTransportPort = Just src,      
                                                            dstTransportPort = Just dst  }  
                    IP.UDPInIP (src,dst)            -> m2 { srcTransportPort = Just src,      
                                                            dstTransportPort = Just dst  }  
                    -- NOTE(review): icmpCode is bound but unused; OpenFlow 1.0
                    -- carries the ICMP code in the destination port field.
                    -- Confirm that 0 is intended here.
                    IP.ICMPInIP (icmpType,icmpCode) -> m2 { srcTransportPort = Just (fromIntegral icmpType), 
                                                            dstTransportPort = Just 0    }  
                    IP.UninterpretedIPBody _        -> m2
                      
            -- ARP reuses the IP fields: opcode in ipProtocol, and the sender
            -- and target addresses as the source and destination prefixes.
            ARPInEthernet (ARPPacket {..}) -> 
              m1 { ipProtocol   = Just ( if arpOpCode == ARPRequest then 1 else 2)
                 , srcIPAddress = senderIPAddress // 32
                 , dstIPAddress = targetIPAddress // 32 
                 }
            UninterpretedEthernetBody _ -> m1


-- | Decode a frame with 'getEthernetFrame' and turn it into the 'Match'
-- that 'frameToExactMatch' builds for it, using the given ingress port.
getExactMatch :: PortID -> GetE Match
getExactMatch inPort =
  getEthernetFrame >>= return . frameToExactMatch inPort


-- | Models the match semantics of an OpenFlow switch.
-- @matches (port, frame) m@ is True when @frame@, received on @port@,
-- satisfies every condition of @m@.
--
-- * A 'Nothing' field is a wildcard and always matches.
--
-- * A condition on a header the frame does not carry counts as satisfied.
--   IP protocol, ToS and address prefixes are checked only for IP packets.
--   Transport ports are checked only for TCP and UDP.
--   NOTE(review): 'frameToExactMatch' also encodes the ARP opcode, ARP
--   addresses and the ICMP type into these fields, but they are not
--   compared here.
--
-- * VLAN ID uses the OFP_VLAN_NONE convention, the same one
--   'frameToExactMatch' uses.  An untagged frame satisfies a VLAN ID
--   condition only when that condition equals 'ofpVlanNone'.  A tagged
--   frame satisfies it only when it equals the frame VLAN ID.
--   A VLAN priority condition is checked only against tagged frames.
matches :: (PortID, EthernetFrame) -> Match -> Bool
#if OPENFLOW_VERSION==151 || OPENFLOW_VERSION==152
matches (inPort, frame@(EthernetFrame ethHeader ethBody)) (Match { inPort=inPort',..}) =
#endif
#if OPENFLOW_VERSION==1
matches (inPort, frame@(EthernetFrame ethHeader ethBody)) (Match { inPort=inPort', ipTypeOfService=ipTypeOfService',..}) =
#endif
    and [maybe True matchesInPort           inPort',
         maybe True matchesSrcEthAddress    srcEthAddress,
         maybe True matchesDstEthAddress    dstEthAddress,
         maybe True matchesVLANID           vLANID,
#if OPENFLOW_VERSION==152 || OPENFLOW_VERSION==1
         maybe True matchesVLANPriority     vLANPriority,
#endif
         maybe True matchesEthFrameType     ethFrameType,
         maybe True matchesIPProtocol       ipProtocol,
#if OPENFLOW_VERSION==1
         maybe True matchesIPToS            ipTypeOfService',
#endif
         matchesIPSourcePrefix srcIPAddress,
         matchesIPDestPrefix dstIPAddress,
         maybe True matchesSrcTransportPort srcTransportPort,
         maybe True matchesDstTransportPort dstTransportPort ]
        where
          matchesInPort p = p == inPort
          matchesSrcEthAddress a = IP.sourceAddress frame == a
          matchesDstEthAddress a = IP.destAddress frame == a
          -- An untagged frame has no VLAN ID.  OpenFlow represents that as
          -- OFP_VLAN_NONE, and frameToExactMatch stores ofpVlanNone for it.
          -- So only a condition equal to ofpVlanNone may match such a frame;
          -- a concrete VLAN ID must not.
          matchesVLANID a =
              case ethHeader of
                EthernetHeader {} -> a == fromIntegral ofpVlanNone
                Ethernet8021Q {..} -> a == vlanId
#if OPENFLOW_VERSION==152 || OPENFLOW_VERSION==1
          -- Priority is only meaningful for tagged frames.
          matchesVLANPriority a =
              case ethHeader of
                EthernetHeader {} -> True
                Ethernet8021Q {..} -> a == priorityCodePoint
#endif
          matchesEthFrameType  t = t == typeCode ethHeader
          -- IPHeader {..} binds a local ipProtocol (the packet protocol),
          -- shadowing the Match field bound by the pattern above.
          matchesIPProtocol protCode =
              case ethBody of
                IPInEthernet (IP.IPPacket (IP.IPHeader {..}) _) -> ipProtocol == protCode
                _ -> True
#if OPENFLOW_VERSION==1
          matchesIPToS tos =
                case ethBody of
                  IPInEthernet (IP.IPPacket (IP.IPHeader {..}) _) -> tos == dscp
                  _ -> True
#endif
          matchesIPSourcePrefix prefix =
              case ethBody of
                IPInEthernet ipPkt -> IP.sourceAddress ipPkt `elemOfPrefix` prefix
                _ -> True
          matchesIPDestPrefix prefix =
              case ethBody of
                IPInEthernet ipPkt -> IP.destAddress ipPkt `elemOfPrefix` prefix
                _ -> True
          matchesSrcTransportPort sp =
              case ethBody of
                IPInEthernet (IP.IPPacket _ ipBody) ->
                    case ipBody of
                      IP.TCPInIP (srcPort, _) -> srcPort == sp
                      IP.UDPInIP (srcPort, _) -> srcPort == sp
                      _ -> True
                _ -> True
          matchesDstTransportPort dp =
              case ethBody of
                IPInEthernet (IP.IPPacket _ ipBody) ->
                    case ipBody of
                      IP.TCPInIP (_, dstPort) -> dstPort == dp
                      IP.UDPInIP (_, dstPort) -> dstPort == dp
                      _ -> True
                _ -> True