{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE FlexibleContexts #-}

module Hans.Layer.Arp (
    ArpHandle
  , runArpLayer

    -- External Interface
  , arpWhoHas
  , arpIP4Packet
  , addLocalAddress
  ) where

import Hans.Address.IP4 (IP4,parseIP4,renderIP4)
import Hans.Address.Mac (Mac,parseMac,renderMac,broadcastMac)
import Hans.Channel
import Hans.Layer
import Hans.Layer.Arp.Table
import Hans.Layer.Ethernet
import Hans.Message.Arp
    (ArpPacket(..),parseArpPacket,renderArpPacket,ArpOper(..))
import Hans.Message.EthernetFrame
import Hans.Timers (delay_)
import Hans.Utils

import Control.Concurrent (forkIO,takeMVar,putMVar,newEmptyMVar)
import Control.Monad (forM_,mplus,guard,unless,when)
import MonadLib (BaseM(inBase),set,get)
import qualified Data.ByteString.Lazy as L
import qualified Data.Map.Strict      as Map
import qualified Data.ByteString      as S


-- Arp -------------------------------------------------------------------------

-- | A handle to a running arp layer.
--
-- Requests are delivered to the layer as 'Arp' actions sent over this
-- channel; 'runArpLayer' forks the loop that reads them back with @receive@.
type ArpHandle = Channel (Arp ())


-- | Start an arp layer.
--
-- Registers a handler for frames with EtherType 0x0806 on the ethernet
-- layer.  Each received payload is queued on this layer's channel as a
-- 'handleIncoming' action.  Then a thread is forked that runs 'loopLayer'
-- over the actions received on that channel, starting from an empty state.
runArpLayer :: ArpHandle -> EthernetHandle -> IO ()
runArpLayer h eth = do
  let onArpFrame = send h . handleIncoming
      initial    = emptyArpState h eth
  addEthernetHandler eth (EtherType 0x0806) onArpFrame
  _ <- forkIO (loopLayer "arp" initial (receive h) id)
  return ()


-- External Interface ----------------------------------------------------------

-- | Lookup the hardware address associated with an IP address.
--
-- Queues a 'whoHas' request and blocks the calling thread until the layer
-- invokes the reply callback.  The reply is 'Just' the mac when the address
-- is local or resolves.  It is 'Nothing' when a pending request is reported
-- as timed out by 'advanceArpTable'.
arpWhoHas :: BaseM m IO => ArpHandle -> IP4 -> m (Maybe Mac)
arpWhoHas h !ip = inBase (newEmptyMVar >>= awaitReply)
  where
  -- hand the layer a callback that fills the MVar, then wait on it
  awaitReply reply = do
    send h (whoHas ip (putMVar reply))
    takeMVar reply

-- | Send an IP packet via the arp layer, to resolve the underlying hardware
-- addresses.
--
-- The addresses and the payload are evaluated to WHNF before the request is
-- queued.  See 'handleOutgoing' for how the frame is built and when it is
-- discarded.
arpIP4Packet :: ArpHandle -> IP4 -> IP4 -> L.ByteString -> IO ()
arpIP4Packet h !src !dst !pkt = send h job
  where
  job = handleOutgoing src dst pkt

-- | Associate an address with a mac in the Arp layer.
--
-- Once registered, ARP requests targeting @ip@ are answered with @mac@.
-- Packets sent from @ip@ use @mac@ as their ethernet source.
addLocalAddress :: ArpHandle -> IP4 -> Mac -> IO ()
addLocalAddress h !ip !mac = send h job
  where
  job = handleAddAddress ip mac


-- Message Handling ------------------------------------------------------------

-- | The monad arp-layer actions run in: a 'Layer' over 'ArpState'.
type Arp = Layer ArpState

-- | State threaded through the arp layer.
data ArpState = ArpState
  { arpTable    :: !ArpTable
    -- ^ Translation table: resolved addresses and outstanding (pending)
    -- requests.
  , arpAddrs    :: !(Map.Map IP4 Mac)
    -- ^ This layer's own addresses, registered with 'addLocalAddress'.
  , arpWaiting  :: !(Map.Map IP4 [Maybe Mac -> IO ()])
    -- ^ Callbacks waiting for an address, newest first; 'runWaiting'
    -- reverses the list so they fire in arrival order.
  , arpEthernet :: {-# UNPACK #-} !EthernetHandle
    -- ^ Ethernet layer used to send frames.
  , arpSelf     :: {-# UNPACK #-} !ArpHandle
    -- ^ Handle to this layer, used to schedule work back onto it (the
    -- delayed table advance in 'whoHas').
  }

-- | Initial layer state: no translations, no local addresses and no
-- waiting callbacks.
emptyArpState :: ArpHandle -> EthernetHandle -> ArpState
emptyArpState self ethH = ArpState
  { arpSelf     = self
  , arpEthernet = ethH
  , arpAddrs    = Map.empty
  , arpWaiting  = Map.empty
  , arpTable    = Map.empty
  }

-- | The ethernet layer this arp layer sends its frames through.
ethernetHandle :: Arp EthernetHandle
ethernetHandle  = do
  st <- get
  return (arpEthernet st)

-- | Record (or refresh) the translation @spa -> sha@, stamped with the
-- current layer time.  Every callback waiting on @spa@ is then woken with
-- @Just sha@.
addEntry :: IP4 -> Mac -> Arp ()
addEntry spa sha = do
  st  <- get
  now <- time
  let tbl = addArpEntry now spa sha (arpTable st)
  -- evaluate the new table before it goes into the layer state
  tbl `seq` set st { arpTable = tbl }
  runWaiting spa (Just sha)

-- | Queue a callback to be run when @addr@ is resolved or its request
-- times out.  New callbacks are pushed onto the front of the list.
addWaiter :: IP4 -> (Maybe Mac -> IO ()) -> Arp ()
addWaiter addr cont = do
  st <- get
  let push = Just . maybe [cont] (cont :)
  set st { arpWaiting = Map.alter push addr (arpWaiting st) }

-- | Remove every callback waiting on @spa@ and emit each one, applied to
-- @sha@, with 'output', oldest first.
runWaiting :: IP4 -> Maybe Mac -> Arp ()
runWaiting spa sha = do
  st <- get
  let queued   = Map.findWithDefault [] spa (arpWaiting st)
      waiting' = Map.delete spa (arpWaiting st)
  -- callbacks are consed on by addWaiter, so reverse to keep arrival order
  forM_ (reverse queued) $ \ cb -> output (cb sha)
  waiting' `seq` set st { arpWaiting = waiting' }

-- | RFC 826 merge step.  If @spa@ already has an entry in the translation
-- table, overwrite it with @sha@ (via 'addEntry') and return 'True'.
-- Otherwise leave the table untouched and return 'False'.
updateExistingEntry :: IP4 -> Mac -> Arp Bool
updateExistingEntry spa sha = refresh `mplus` return False
  where
  refresh = do
    st <- get
    guard (Map.member spa (arpTable st))
    addEntry spa sha
    return True

-- | The mac registered for one of this layer's own addresses.  When @pa@ is
-- not local, 'just' presumably makes the action fail.  'whoHas' relies on
-- this by recovering with 'mplus'.
localHwAddress :: IP4 -> Arp Mac
localHwAddress pa = just . Map.lookup pa . arpAddrs =<< get

-- | Wrap an arp packet in an ethernet frame and queue it for transmission.
-- The frame goes from the packet's sender hardware address to its target
-- hardware address.
sendArpPacket :: ArpPacket Mac IP4 -> Arp ()
sendArpPacket msg = do
  eth <- ethernetHandle
  output (sendEthernet eth frameHdr arpBytes)
  where
  frameHdr = EthernetFrame
    { etherDest   = arpTHA msg
    , etherSource = arpSHA msg
    , etherType   = 0x0806
    }
  arpBytes = renderArpPacket renderMac renderIP4 msg

-- | Step the translation table forward to the current layer time (see
-- 'stepArpTable').  Every address it reports as timed out has its waiting
-- callbacks run with 'Nothing'.
advanceArpTable :: Arp ()
advanceArpTable  = do
  now   <- time
  state <- get
  let (table', timedOut) = stepArpTable now (arpTable state)
  -- Force the new table before storing it, as the other state updates in
  -- this module do.  This runs on every query and from delayed timers, so
  -- a lazily stored table could build up a chain of thunks.
  table' `seq` set state { arpTable = table' }
  forM_ timedOut $ \ x -> runWaiting x Nothing

-- | Handle a who-has request: arrange for @k@ to receive the hardware
-- address of @ip@.
--
-- * If @ip@ is one of this layer's own addresses, @k@ gets its mac at once.
-- * Otherwise the table is advanced and consulted:
--
--     * a known address is passed to @k@ immediately;
--     * if a request is already pending, @k@ is just queued;
--     * an unknown address is marked pending and @k@ is queued.  A request
--       is sent from every local address, and a delayed 'advanceArpTable'
--       is scheduled so the request can time out.  Waiters then receive
--       'Nothing'.
--
-- The callback is never run inline; it is always emitted via 'output'
-- (directly here, or later by 'runWaiting').
whoHas :: IP4 -> (Maybe Mac -> IO ()) -> Arp ()
whoHas ip k = (k' =<< localHwAddress ip) `mplus` query
  where
  -- deliver a resolved address to the caller's callback
  k' addr = output (k (Just addr))

  query = do
    -- bring the table up to date before looking ip up
    advanceArpTable
    -- re-read the state: advanceArpTable has just replaced it
    state  <- get
    case lookupArpEntry ip (arpTable state) of
      KnownAddress mac    -> k' mac
      -- a request is already outstanding; wait for its outcome
      Pending             -> addWaiter ip k
      Unknown             -> do
        let addrs = Map.toList (arpAddrs state)
            -- one request per local (protocol, hardware) address pair.
            -- arpHwType/arpPType: presumably Ethernet (1) and IPv4 (0x0800,
            -- the etherType handleOutgoing uses for IP packets).
            -- arpTHA doubles as the ethernet destination in sendArpPacket,
            -- so broadcastMac here broadcasts the request.
            msg (spa,sha) = ArpPacket
              { arpHwType = 0x1
              , arpPType  = 0x0800
              , arpSHA    = sha
              , arpSPA    = spa
              , arpTHA    = broadcastMac
              , arpTPA    = ip
              , arpOper   = ArpRequest
              }
        now <- time
        let table' = addPending now ip (arpTable state)
        set state { arpTable = table' }
        -- addWaiter reads the state itself, so it must follow the set above
        addWaiter ip k
        mapM_ (sendArpPacket . msg) addrs
        -- re-run advanceArpTable later so a request that never gets an
        -- answer can time out.
        -- NOTE(review): the unit of delay_ is not visible here (presumably
        -- milliseconds, i.e. 10s) -- confirm in Hans.Timers.
        output (delay_ 10000 (send (arpSelf state) advanceArpTable))

-- | Process an incoming arp packet.
--
-- Follows the packet-reception algorithm of RFC 826, whose steps are quoted
-- in the comments below.  In outline:
--
-- * The payload is parsed.  'liftRight' presumably aborts the action on a
--   parse error, so malformed packets are dropped.
-- * An existing table entry for the sender is refreshed (the merge step).
-- * If the target protocol address is not one of ours, 'localHwAddress'
--   presumably makes the action stop here.
-- * Otherwise the sender is added to the table if it was not merged.  A
--   request is answered with a reply addressed to the requester.
handleIncoming :: S.ByteString -> Arp ()
handleIncoming bs = do
  msg <- liftRight (parseArpPacket parseMac parseIP4 bs)
  -- ?Do I have the hardware type in ar$hrd
  -- Yes: (This check is enforced by the type system)
  --   [optionally check the hardware length ar$hln]
  --   ?Do I speak the protocol in ar$pro?
  --   Yes: (This check is also enforced by the type system)
  --     [optionally check the protocol length ar$pln]
  --     Merge_flag := false
  --     If the pair <protocol type, sender protocol address> is
  --         already in my translation table, update the sender
  --         hardware address field of the entry with the new
  --         information in the packet and set Merge_flag to true.
  let sha = arpSHA msg
  let spa = arpSPA msg
  merge <- updateExistingEntry spa sha
  --     ?Am I the target protocol address?
  let tpa = arpTPA msg
  -- (The "No" branch: localHwAddress fails when tpa is not a local
  -- address, which presumably ends processing of this packet here.)
  lha <- localHwAddress tpa
  --     Yes:
  --       If Merge_flag is false, add the triplet <protocol type,
  --           sender protocol address, sender hardware address> to
  --           the translation table.
  unless merge (addEntry spa sha)
  --       ?Is the opcode ares_op$REQUEST?  (NOW look at the opcode!!)
  --       Yes:
  when (arpOper msg == ArpRequest) $ do
  --           Swap hardware and protocol fields, putting the local
  --               hardware and protocol addresses in the sender fields.
    let msg' = msg { arpSHA = lha , arpSPA = tpa
                   , arpTHA = sha , arpTPA = spa
  --           Set the ar$op field to ares_op$REPLY
                   , arpOper = ArpReply }
  --           Send the packet to the (new) target hardware address on
  --               the same hardware on which the request was received.
  --           (sendArpPacket uses arpTHA, now the requester's mac, as the
  --           ethernet destination.)
    sendArpPacket msg'


-- | Handle a request to associate an ip with a mac address for a local device.
-- Any earlier mac registered for @ip@ is replaced.
handleAddAddress :: IP4 -> Mac -> Arp ()
handleAddAddress ip mac = get >>= \ st ->
  let addrs' = Map.insert ip mac (arpAddrs st)
   in addrs' `seq` set st { arpAddrs = addrs' }


-- | Output a packet to the ethernet layer.
--
-- The frame's source is the mac registered for @src@.  Its destination is
-- found by resolving @dst@ with 'whoHas'.  When @dst@ does not resolve,
-- the packet is discarded.
handleOutgoing :: IP4 -> IP4 -> L.ByteString -> Arp ()
handleOutgoing src dst body = do
  eth <- ethernetHandle
  lha <- localHwAddress src
  let transmit dha = sendEthernet eth (ipFrame lha dha) body
  whoHas dst (maybe (return ()) transmit)
  where
  ipFrame srcMac dstMac = EthernetFrame
    { etherSource = srcMac
    , etherDest   = dstMac
    , etherType   = 0x0800
    }