{-# LANGUAGE Arrows, DisambiguateRecordFields, RecordWildCards, TypeSynonymInstances, CPP #-}

-- | This module defines a relational view of the network
-- state and configuration, and provides signal functions
-- that dynamically maintain this view.
module Nettle.FRPControl.NetInfo
    (
    -- * Switch and port information
    SwitchTable,
    SwitchRecord(..),
    PortTable,
    PortRecord(..),
    HasDataPathID(..),
    NetInfo, 
    portTable, 
    switchTable,
    NetworkMonitorPolicy(..),
    defaultMonitorPolicy,
    networkInfoRequester,
    switchInfo,
    
    -- * Host information
    HostDirectionMap,
    hostDirectionsSF,
    hostDirectionsChangeSF,
    HostLocationMap,
    hostLocationSF,

    -- * Port Statistics
    withPortStats,
    portRatesMapSF,
    nAveragePortRateMap    
    ) where

import Nettle.OpenFlow.Messages hiding (Features)
import Nettle.OpenFlow.Switch hiding (SwitchFeatures(..))
import qualified Nettle.OpenFlow.Switch as M
import qualified Nettle.OpenFlow.Port as P
import Nettle.OpenFlow.Port hiding (Port,portID)
import Nettle.OpenFlow.Action
import Nettle.OpenFlow.Packet 
import Nettle.OpenFlow.Statistics hiding (StatsReply(..))
import Nettle.Ethernet.EthernetAddress
import Nettle.IPv4.IPPacket
import Nettle.FRPControl.NettleSF 
import Nettle.Discovery.Topology 
import Data.List as List
import Data.Map (Map)
import qualified Data.Map as Map
import Data.Set (Set)
import qualified Data.Set as Set
import Data.Monoid 
  

-- | All known switches, one 'SwitchRecord' per switch.
type SwitchTable = [SwitchRecord]

-- | Per-switch information, projected from a switch's features by 'switchTable'.
data SwitchRecord =
  SwitchRecord
    { switchID         :: SwitchID           -- ^ Datapath identifier of the switch.
    , packetBufferSize :: Integer            -- ^ Maximum number of packets buffered at the switch.
    , numberFlowTables :: Integer            -- ^ Number of flow tables.
    , capabilities     :: [SwitchCapability] -- ^ Capabilities the switch advertises.
    , supportedActions :: [ActionType]       -- ^ Action types the switch supports.
    }
  deriving (Show, Eq, Ord)

-- | All known ports. Each 'PortRecord' is expected to be uniquely
-- identified by its ('portSwitch', 'portID') pair.
type PortTable = [PortRecord]

-- | Per-port information, as projected by 'portTable'.
data PortRecord =
  PortRecord
    { portSwitch        :: SwitchID              -- ^ Switch the port belongs to.
    , portID            :: PortID                -- ^ Port ID of the port on that switch.
    , portAddr          :: EthernetAddress       -- ^ Hardware (Ethernet) address of the port.
    , spanningTreeState :: SpanningTreePortState -- ^ Spanning tree protocol state for this port.
    , isPortDown        :: Bool                  -- ^ True when the port configuration contains 'PortDown'.
    , isLinkDown        :: Bool                  -- ^ Link-down flag reported for the port.
    , isUsedForFlooding :: Bool                  -- ^ False when the port configuration contains 'NoFlooding'.
    }
  deriving (Show, Eq)



-- | Values that carry a datapath (switch) identifier. Giving records an
-- instance lets clients read that identifier through one function,
-- 'dPID', whatever the record type.
class HasDataPathID a where
  -- | The switch identifier associated with the value.
  dPID :: a -> SwitchID

-- | A switch identifier is its own datapath identifier.
instance HasDataPathID SwitchID where
  dPID sid = sid

-- | A port's datapath identifier is the switch it belongs to.
instance HasDataPathID PortRecord where
  dPID record = portSwitch record


-- | Network information keyed by switch: the features value most recently
-- recorded for each switch (see 'switchInfo'). Use 'switchTable' and
-- 'portTable' to obtain relational views of it.
newtype NetInfo = NetInfo (Map SwitchID M.SwitchFeatures)
  deriving (Show)

-- | Project the current @SwitchTable@ from a @NetInfo@ value: one record
-- per known switch, in ascending switch-identifier order.
switchTable :: NetInfo -> SwitchTable
switchTable (NetInfo featureMap) = map toRecord (Map.toList featureMap)
  where
    toRecord (dpid, features) =
      SwitchRecord { switchID         = dpid
                   , packetBufferSize = M.packetBufferSize features
                   , numberFlowTables = M.numberFlowTables features
                   , capabilities     = M.capabilities features
                   , supportedActions = M.supportedActions features
                   }

-- | Project the current @PortTable@ from a @NetInfo@ value: one record per
-- port of every known switch. Ports whose number exceeds @maxNumberPorts@
-- are skipped.
portTable :: NetInfo -> PortTable
portTable (NetInfo smap) = concatMap portsOf (Map.elems smap)
    where
      portsOf sfr =
        [ PortRecord { portSwitch        = M.switchID sfr
                     , portID            = num
                     , portAddr          = addr
                     , spanningTreeState = stpState
                     , isPortDown        = PortDown `elem` cfg
                     , isLinkDown        = down
                     , isUsedForFlooding = NoFlooding `notElem` cfg
                     }
        | P.Port num _ addr cfg down stpState _ _ _ _ <- M.ports sfr
        , num <= maxNumberPorts
        ]


-- | For each (switch, host Ethernet address) pair, the port of that switch
-- on which packets from the host were last received. This is the next hop
-- from that switch towards the host.
type HostDirectionMap = Map.Map (SwitchID, EthernetAddress) PortID

-- | The current @HostDirectionMap@. It is empty until the first change
-- event and holds the latest map between changes.
hostDirectionsSF :: SF (Event (SwitchID, SwitchMessage)) HostDirectionMap
hostDirectionsSF = proc msgE -> do
    changeE <- hostDirectionsChangeSF -< msgE
    hold Map.empty -< liftE snd changeE

-- | Outputs an event whenever the host direction information changes,
-- carrying the map before and after the change.
--
-- Only packet-in messages whose frame parses ('Right') are considered.
-- A packet updates the map and fires an event when it either:
--
--   * comes from a (switch, source address) pair not seen before, or
--   * arrives on a different port than the one recorded.
--
-- A packet that confirms the recorded port fires nothing.
hostDirectionsChangeSF :: SF (Event (SwitchID, SwitchMessage)) (Event (HostDirectionMap, HostDirectionMap))
hostDirectionsChangeSF = arr packetInE >>> accumFilter learn Map.empty
  where
    learn dict (sid, pktRecord) =
      either (const (dict, Nothing))
             (observe dict sid (receivedOnPort pktRecord))
             (packetInFrame pktRecord)

    observe dict sid inPort frame
      | Map.lookup key dict == Just inPort = (dict, Nothing)
      | otherwise                          = (dict', Just (dict, dict'))
      where
        key   = (sid, sourceAddress frame)
        dict' = Map.insert key inPort dict


-- | Maps each host's Ethernet address to the (switch, port) where its
-- packets were last received on a port outside the topology
-- (see 'hostLocationSF').
type HostLocationMap = Map.Map EthernetAddress (SwitchID, PortID)

-- | Tracks where hosts are attached. For each packet-in whose frame parses,
-- the (switch, port) it arrived on is recorded under the frame's source
-- address, unless that port belongs to the current topology
-- ('portInTopology'). The map starts empty and is held between updates.
hostLocationSF :: SF (Event (SwitchID, SwitchMessage), Topology) HostLocationMap
hostLocationSF = proc (i, topology) -> do
    locationsE <- accumBy learn Map.empty -< packetInE i `attach` topology
    hold Map.empty -< locationsE
  where
    learn hlMap ((dpid, pktInfo), topo)
      | (dpid, inPort) `portInTopology` topo = hlMap
      | otherwise =
          either (const hlMap)
                 (\frame -> Map.insert (sourceAddress frame) (dpid, inPort) hlMap)
                 (packetInFrame pktInfo)
      where inPort = receivedOnPort pktInfo
                        


-- | True when the port is a member of at least one element of the topology.
-- Each element is presumably a link, i.e. a set of ports; confirm against
-- Nettle.Discovery.Topology.
portInTopology :: Port -> Topology -> Bool
portInTopology port topology = any (Set.member port) (Set.toList topology)



                                                   
-- | How often the network monitor polls switches for information.
data NetworkMonitorPolicy =
  NetworkMonitorPolicy
    { switchFeaturesRefreshPeriod :: Time -- ^ Amount of time (in seconds) between switch feature queries
    , portStatisticsRefreshPeriod :: Time -- ^ Amount of time (in seconds) between port statistics queries
    }
  deriving (Show, Eq)


-- | Default polling policy: switch features every 10 seconds, port
-- statistics every 5 seconds.
defaultMonitorPolicy :: NetworkMonitorPolicy
defaultMonitorPolicy =
  NetworkMonitorPolicy
    { portStatisticsRefreshPeriod = 5
    , switchFeaturesRefreshPeriod = 10
    }

-- | Issues switch queries according to the given @NetworkMonitorPolicy@.
-- Feature requests and port-statistics requests go to the currently active
-- switches; the two command streams are combined with @mergeBy (<+>)@.
networkInfoRequester :: NetworkMonitorPolicy ->
                        SF (Event (SwitchID, SwitchMessage)) (Event SwitchCommand)
networkInfoRequester policy = proc msgE -> do
    switches <- activeSwitches -< msgE
    featureE <- switchFeatureMonitor featurePeriod -< (msgE, switches)
    statsE   <- requestPortFlows statsPeriod -< (msgE, switches)
    returnA -< mergeBy (<+>) featureE statsE
  where
    featurePeriod = switchFeaturesRefreshPeriod policy
    statsPeriod   = portStatisticsRefreshPeriod policy

-- | Every @refreshPeriod@, emits one command that requests features from
-- every switch in the current set. The first input component is ignored.
switchFeatureMonitor :: Time -> SF (i, Set SwitchID) (Event SwitchCommand)
switchFeatureMonitor refreshPeriod = proc (_, dpids) -> do
    tick <- repeatedly refreshPeriod () -< ()
    returnA -< tick `tag` mconcat (map requestFeatures (Set.toList dpids))

-- | The set of switches that have arrived and not departed since. The set
-- starts empty and is held between arrivals and departures.
activeSwitches :: SF (Event (SwitchID, SwitchMessage)) (Set SwitchID)
activeSwitches = proc i -> do
    let joins  = liftE (\(dpid, _) -> Set.insert dpid) (arrivalE i)
        leaves = liftE (Set.delete . fst) (departureE i)
    switchSetE <- accum Set.empty -< joins `lMerge` leaves
    hold Set.empty -< switchSetE

switchInfo :: SF (Event (SwitchID, SwitchMessage)) NetInfo
switchInfo = proc i -> do 
  let update = liftE (\(sw, sfr) -> Map.insert sw sfr) (arrivalE i) `lMerge`
               liftE (\(sw, _)   -> Map.delete sw) (departureE i) `lMerge`
               liftE (\(sw, sfr) -> Map.insert sw sfr) (featureUpdateE i)
  arr NetInfo <<< hold Map.empty <<< accum Map.empty -< update

-- | Issues port-statistics requests to every switch in the given set. A
-- request is sent to every switch once per @refreshPeriod@ (driven by
-- @repeatedly@), and immediately to each switch reported by 'arrivalE'.
-- The periodic and on-arrival streams are combined with @mergeBy (<+>)@.
--
-- NOTE(review): only OPENFLOW_VERSION values 1, 151 and 152 provide a
-- definition. With any other value this signature has no binding and the
-- module will not compile.
requestPortFlows :: Time -> SF (Event (SwitchID, SwitchMessage), Set SwitchID) (Event SwitchCommand)
#if OPENFLOW_VERSION==151 || OPENFLOW_VERSION==152
-- In this branch PortStatsRequest is used without a port argument.
requestPortFlows refreshPeriod = 
    proc (i, dpids) -> do 
      clock <- repeatedly refreshPeriod () -< ()
      let periodicQuery = tag clock ( mconcat [requestStats dpid PortStatsRequest | dpid <- Set.toList dpids])
      let joinQuery     = liftE (\(dpid,_) -> requestStats dpid PortStatsRequest) (arrivalE i)
      -- NOTE(review): the OPENFLOW_VERSION==1 branch ends with returnA here.
      -- Confirm that writeToSwitch is in scope and yields Event SwitchCommand
      -- as the signature requires.
      writeToSwitch -< mergeBy (<+>) joinQuery periodicQuery
#endif
#if OPENFLOW_VERSION==1
-- In this branch every request asks for statistics of AllPorts.
requestPortFlows refreshPeriod = 
    proc (i, dpids) -> do 
      clock <- repeatedly refreshPeriod () -< ()
      let periodicQuery = tag clock ( mconcat [requestStats dpid (PortStatsRequest AllPorts) | dpid <- Set.toList dpids])
      let joinQuery     = liftE (\(dpid,_) -> requestStats dpid (PortStatsRequest AllPorts)) (arrivalE i)
      returnA -< mergeBy (<+>) joinQuery periodicQuery
#endif


-- | Applies the given signal function to each switch-port pair in the
-- network.
--
--   * When a switch arrives, a copy of @sf@, fed with that port's statistics
--     replies, is started for each of its ports numbered at most
--     @maxNumberPorts@.
--   * New entries take precedence over existing entries with the same key.
--   * When a switch departs, all of its copies are removed.
--
-- Output is the map of current outputs keyed by (switch, port).
withPortStats :: SF (Event PortStats) a -> SF (Event (SwitchID, SwitchMessage)) (Map (SwitchID, PortID) a)
withPortStats sf = proc i -> do
    let added   = liftE (\(swid, sfr) -> Map.union (portSFs swid sfr)) (arrivalE i)
        removed = liftE (\(swid, _) -> Map.filterWithKey (\(owner, _) _ -> swid /= owner)) (departureE i)
    rpSwitchB Map.empty -< (i, added `lMerge` removed)
  where
    portSFs swid sfr = Map.fromList
      [ ((swid, pid), statsForPort swid pid >>> sf)
      | pid <- filter (<= maxNumberPorts) (map P.portID (M.ports sfr))
      ]


-- | Selects, from incoming switch messages, the statistics reported for port
-- @pid@ of switch @dpid@. Port-statistics replies from other switches, and
-- replies that do not list @pid@, produce no event.
statsForPort :: SwitchID -> PortID -> SF (Event (SwitchID, SwitchMessage)) (Event PortStats)
statsForPort dpid pid = arr (mapFilterE select . portStatReplyE)
  where
    -- The match uses 'otherwise' so GHC sees it as total; a @/=@ / @==@
    -- guard pair is reported as non-exhaustive.
    select (swid, portStats)
      | swid == dpid = lookup pid portStats
      | otherwise    = Nothing


-- | For every switch port, computes the rate of change of its port
-- statistics vector. Each component is the difference between the last two
-- samples divided by the time between them. These measurements are held
-- until the next sample is observed; before the first rate is available
-- the value is @nullPortStats@.
portRatesMapSF :: SF (Event (SwitchID, SwitchMessage)) (Map (SwitchID, PortID) PortStats)
portRatesMapSF = withPortStats perPort
  where perPort = oneStepDifferenceSF >>> hold nullPortStats

-- | For every switch port, the rate produced by 'averageRateN' over @n@
-- samples. It is held between updates and starts at @nullPortStats@.
averageRateMap :: Int -> SF (Event (SwitchID, SwitchMessage)) (Map (SwitchID, PortID) PortStats)
averageRateMap n = withPortStats $ averageRateN n >>> hold nullPortStats

-- | Tracks the average rate of change of the port statistics over a sliding
-- window of the last @n@ samples, where @n@ is the first argument.
--
-- An event is emitted only when both hold:
--
--   * the window holds exactly @n@ samples, and
--   * the newest sample is strictly later than the oldest.
--
-- The rate is the component-wise difference between the newest and oldest
-- samples, divided by the elapsed time. Consequently @n < 2@ never emits,
-- instead of dividing by a zero time span.
averageRateN :: Int -> SF (Event PortStats) (Event PortStats)
averageRateN n =
    proc aEvent -> do
      t <- time -< ()
      accumFilter step [] -< aEvent `attach` t
    where
      step window (a, t) = (window', rate)
        where
          -- Newest sample first; keep at most n samples.
          window'      = (a, t) : take (n - 1) window
          -- window' always starts with (a, t), so 'last' cannot fail here.
          (vOld, tOld) = last window'
          dt           = t - tOld
          rate
            | length window' == n && dt > 0 =
                Just (liftIntoPortStats1 (/ dt) (liftIntoPortStats2 (-) a vOld))
            | otherwise = Nothing
                    
-- | On each input event, emits the (time, value) pairs of the most recent
-- @n@ events, this one included, newest first.
nEvents :: Int -> SF (Event a) (Event [(Time,a)])
nEvents n = proc e -> do
    now <- time -< ()
    accum [] -< liftE (\a history -> take n ((now, a) : history)) e

-- | Average rate of change over (up to) the last @n@ statistics samples.
--
-- On each input event, the newest sample is compared with the oldest sample
-- still in the window. Their component-wise difference is divided by the
-- elapsed time.
--
-- Nothing is emitted when:
--
--   * the window holds fewer than two samples (always the case for @n < 2@),
--   * or the two samples share a timestamp (this avoids a division by zero).
nAverage :: Int -> SF (Event PortStats) (Event PortStats)
nAverage n = proc e -> do
    samplesE <- nEvents n -< e
    returnA -< mapFilterE average samplesE
  where
    average ((tNew, vNew) : older@(_ : _))
      | dt > 0    = Just (liftIntoPortStats1 (/ dt) (liftIntoPortStats2 (-) vNew vOld))
      | otherwise = Nothing
      where
        -- 'older' is non-empty by the pattern above, so 'last' is safe.
        (tOld, vOld) = last older
        dt           = tNew - tOld
    average _ = Nothing

-- | For every switch port, the average rate of change of its port statistics
-- over (up to) the last @n@ samples, as computed by 'nAverage'. Each
-- port's rate is held between samples and starts at @zeroPortStats@.
nAveragePortRateMap :: Int -> SF (Event (SwitchID, SwitchMessage)) (Map (SwitchID, PortID) PortStats)
nAveragePortRateMap n = withPortStats perPort
  where perPort = nAverage n >>> hold zeroPortStats

-- | For every switch port, the output of 'portRatesSmooth' applied to that
-- port's statistics.
smoothPortRateMap :: SF (Event (SwitchID, SwitchMessage)) (Map (SwitchID, PortID) PortStats)
smoothPortRateMap = withPortStats (portRatesSmooth)

-- | Processes one port's statistics in three steps:
--
--   1. Takes consecutive-sample differences twice.
--   2. Holds the result, starting from @zeroPortStats@.
--   3. Integrates it continuously over time.
--
-- NOTE(review): presumably intended as a smoothed rate estimate; confirm.
portRatesSmooth :: SF (Event PortStats) PortStats
portRatesSmooth = proc statsE -> do
    rateE  <- oneStepDifferenceSF -< statsE
    accelE <- oneStepDifferenceSF -< rateE
    accel  <- hold zeroPortStats  -< accelE
    integralPortStats -< accel

-- | Integrates every component of a port-statistics signal over time.
-- Missing components ('Nothing') count as 0, and every component of the
-- result is 'Just'.
integralPortStats :: SF PortStats PortStats
integralPortStats = proc v -> do
    receivedPackets'      <- integral -< orZero (portStatsReceivedPackets v)
    sentPackets'          <- integral -< orZero (portStatsSentPackets v)
    receivedBytes'        <- integral -< orZero (portStatsReceivedBytes v)
    sentBytes'            <- integral -< orZero (portStatsSentBytes v)
    receiverDropped'      <- integral -< orZero (portStatsReceiverDropped v)
    senderDropped'        <- integral -< orZero (portStatsSenderDropped v)
    receiveErrors'        <- integral -< orZero (portStatsReceiveErrors v)
    transmitErrors'       <- integral -< orZero (portStatsTransmitError v)
    receivedFrameErrors'  <- integral -< orZero (portStatsReceivedFrameErrors v)
    receiverOverrunError' <- integral -< orZero (portStatsReceiverOverrunError v)
    receiverCRCError'     <- integral -< orZero (portStatsReceiverCRCError v)
    collisions'           <- integral -< orZero (portStatsCollisions v)
    returnA -< PortStats { portStatsReceivedPackets      = Just receivedPackets'
                         , portStatsSentPackets          = Just sentPackets'
                         , portStatsReceivedBytes        = Just receivedBytes'
                         , portStatsSentBytes            = Just sentBytes'
                         , portStatsReceiverDropped      = Just receiverDropped'
                         , portStatsSenderDropped        = Just senderDropped'
                         , portStatsReceiveErrors        = Just receiveErrors'
                         , portStatsTransmitError        = Just transmitErrors'
                         , portStatsReceivedFrameErrors  = Just receivedFrameErrors'
                         , portStatsReceiverOverrunError = Just receiverOverrunError'
                         , portStatsReceiverCRCError     = Just receiverCRCError'
                         , portStatsCollisions           = Just collisions'
                         }
  where
    -- A missing statistic contributes nothing to the integral.
    orZero mx = maybe 0 id mx


-- | For each pair of consecutive statistics samples, emits the
-- component-wise difference divided by the time between them. Nothing is
-- emitted for the first sample.
oneStepDifferenceSF :: SF (Event PortStats) (Event PortStats)
oneStepDifferenceSF = consecutiveEvents >>> arr (liftE rate)
  where
    rate ((earlier, tEarlier), (later, tLater)) =
      liftIntoPortStats2 (\x y -> (y - x) / (tLater - tEarlier)) earlier later



-- | Timestamps each input event and pairs it with the previous one, as
-- ((previous value, previous time), (current value, current time)).
-- The very first event produces no output.
consecutiveEvents :: SF (Event a) (Event ((a,Time),(a,Time)))
consecutiveEvents = proc aEvent -> do
    now <- time -< ()
    accumFilter pairUp Nothing -< aEvent `attach` now
  where
    pairUp previous current = (Just current, fmap (\p -> (p, current)) previous)