-- | Transmission of data via TFTP. This implements the stop-and-wait style data
-- transmission protocol.
module Network.TFTP.Protocol where

import           Network.TFTP.Types

import qualified Network.TFTP.UDPIO as UDP
import qualified Network.TFTP.Message as M

-- | XFer monad parameterised over a (MessageIO) monad.
type XFerT m address a = StateT (XFerState address) m a

-- | Execute a transfer action.
runTFTP :: (MessageIO m address) => XFerT m address result -> m result
runTFTP action = evalStateT action (XFerState 0 Nothing)

-- | A simple server action that will wait for a RRQ for its file.
offerSingleFile :: (MessageIO m address) => Maybe Int -> String -> ByteString -> XFerT m address Bool
offerSingleFile timeoutSeconds fname content = do
  req <- receive timeoutSeconds
  case req of
    Just (M.RRQ rfname mode) | rfname == fname  -> do
      case mode of
        M.NetASCII -> do
          printErr "A client requested NetASCII, this is not implemented yet."
          reply $ M.Error $ M.IllegalTFTPOperation
          return False

        M.Octet -> do
          printInfo $ printf "Accepting RRQ for %s sending %i bytes!" fname (blength content)
          resetBlockIndex
          incBlockIndex
          writeData content

    Just (M.RRQ rfname _) -> do
      printErr $ printf "Client request for %s but I can send only %s" rfname fname
      reply $ M.Error $ M.FileNotFound
      return False

    Just req -> do
      printErr $ printf "Invalid client request %s" (show req)
      reply $ M.Error $ M.IllegalTFTPOperation
      return False

    Nothing -> do
      printErr $ printf "Timeout offering single file '%s'" fname
      return False

-- | A transfer action that sends a large chunk of data via TFTP DATA messages
-- to a destination.
writeData :: (MessageIO m address) => ByteString -> XFerT m address Bool
writeData block = write maxRetries block
    where
      write retries blk = do
        replyData (btake 512 blk)
        continueAfterACK (writeNext blk) (retryThisWrite retries blk) writeFailed

      writeNext blk = do
        if blength blk >= 512 then do
                                incBlockIndex
                                write maxRetries (bdrop 512 blk)
         else do
          printInfo $ printf "Write finished"
          return True

      retryThisWrite retries blk =
          if retries == 0 then writeFailed else
              do printWarn $ printf "Retrying..."
                 write (retries - 1) blk

      writeFailed = do
        blockIdx <- getBlockIndex
        printErr $ printf "Write failed after %i retransmissions" maxRetries
        reply $ M.Error $ M.ErrorMessage "timeout"
        return False


-- | Receive the next message from the client, if the client anserws with the
-- correct ack call 'success'. If there was a timeout or the ack was for an
-- invalid index call 'retry', if an error occured call 'error
continueAfterACK success retry fail = do
  currentIdx <- getBlockIndex
  packet <- receive ackTimeOut
  case packet of
    Just (M.Error err) ->
        do printErr $ printf "Error message received:  (%s)  "  (show err)
           fail

    Just (M.ACK idx) | idx == currentIdx ->
       do printInfo $ printf "Acknowledged"
          success

    Just (M.ACK idx) | idx /= currentIdx ->
       do printWarn $ printf "ACK invalid"
          retry

    Just otherMsg ->
       do printErr $ printf "Unexpected message"
          fail

    -- this indicates a timeout
    Nothing -> retry

-- | The default number of re-transmits during 'writeData'
maxRetries :: Int
maxRetries = 30

-- | The default time 'continueAfterACK' waits for an ACK.
ackTimeOut = Just 3

-- | Internal state record for a transfer
data XFerState address = XFerState { xsBlockIndex  :: Word16
                                     -- ^ The block index of an ongoing transfer
                                   , xsFrom        :: Maybe address
                                     -- ^ Origin of the last message received
                                  }

-- | Reset the current block index for an ongoing transfer to 0
resetBlockIndex :: Monad m => XFerT m address ()
resetBlockIndex = modify (\st -> st {xsBlockIndex = 0})

-- | Read the current block index for an ongoing transfer
getBlockIndex :: Monad m => XFerT m address Word16
getBlockIndex = get >>= return . xsBlockIndex

-- | Increment the current block index for an ongoing transfer
incBlockIndex :: Monad m => XFerT m address Word16
incBlockIndex = do
   st <- get
   let i = xsBlockIndex st
       i' = i + 1
   modify (\ st -> st { xsBlockIndex = i' })
   return i'

-- | Return the origin('Address') of the message last received, or 'Nothing'
getLastPeer :: Monad m => XFerT m address (Maybe address)
getLastPeer = get >>= return . xsFrom

-- | Overwrite the origin('Address') of the message last received
setLastPeer :: (MessageIO m address) => Maybe address -> XFerT m address ()
setLastPeer addr = do
  lp <- getLastPeer
  when (lp /= addr) $ do
                     printInfo $ printf "Replacing last peer with (%s)" (show addr)
                     modify $ \st -> st { xsFrom = addr }

-- | Send a 'M.DATA' packet to the origin('Address') of the message last received with the current block index
replyData :: (MessageIO m address) => ByteString -> XFerT m address ()
replyData chunk = do
  idx <- getBlockIndex
  reply (M.DATA idx chunk)

-- | Send any 'M.Message' to the address to where the last message received from
reply :: (MessageIO m address) => M.Message -> XFerT m address ()
reply msg = do
  lp <- getLastPeer
  Just dest <- getLastPeer
  send dest msg

-- | Send any 'M.Message' to an 'Address'
send :: (MessageIO m address) => address -> M.Message -> XFerT m address ()
send dest msg = do
  let msg' = M.encode msg
  lift $ sendTo dest msg'
  printInfo $ printf "Sent message to (%s) (%i bytes)" (show dest) (blength msg')

-- | receive a message and remeber the sender for 'getLastPeer'
receive :: (MessageIO m address) => Maybe Int -> XFerT m address (Maybe M.Message)
receive timeout = do
  res <- lift (receiveFrom timeout)
  case res of
    Just (from, msg) -> do
      setLastPeer (Just from)
      let msg' = M.decode msg
      printInfo (printf "Received msg (%i bytes)" (blength msg))
      return (Just msg')
    Nothing -> do
      printWarn "Receive timeout"
      return Nothing

-- | Log debug message
printInfo :: (MessageIO m address) => String -> XFerT m address ()
printInfo = logWith debugM

-- | Log warning message
printWarn :: (MessageIO m address) => String -> XFerT m address ()
printWarn = logWith warningM

-- | Log error message
printErr :: (MessageIO m address) => String -> XFerT m address ()
printErr = logWith errorM

-- | Log message with custom priority
logWith :: (MessageIO m address) =>
           (String -> String -> IO ()) -> String -> XFerT m address ()
logWith f m = do
  la <- lift localAddress
  idx <- getBlockIndex
  from <- getLastPeer
  let m' = printf "%s @ block #%i <%s> <%s>" m idx (show la) (show from)
  liftIO (f "TFTP.Protocol" m')