{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE NamedFieldPuns #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE PatternSynonyms #-}
module Network.Legion.Discovery (
  main
) where

import Canteven.HTTP (requestLogging, logExceptionsAndContinue,
  DecodeResult(Unsupported, BadEntity, Ok), FromEntity, decodeEntity)
import Canteven.Log.MonadLog (getCantevenOutput)
import Control.Monad.IO.Class (MonadIO, liftIO)
import Data.Aeson (encode, object, (.=), FromJSON, parseJSON,
  Value(Object), (.:), eitherDecode)
import Data.ByteString (ByteString, hGetContents)
import Data.Conduit ((=$=), runConduit)
import Data.GraphViz (graphvizWithHandle, GraphvizCommand(Dot),
  GraphvizOutput(Svg))
import Data.GraphViz.Printing (renderDot, toDot)
import Data.GraphViz.Types.Canonical (DotGraph)
import Data.Map (Map)
import Data.Monoid ((<>))
import Data.String (IsString)
import Data.Text (pack, unpack)
import Data.Text.Encoding (decodeUtf8, encodeUtf8)
import Data.Time (getCurrentTime)
import Data.Version (showVersion)
import Distribution.Text (simpleParse)
import Distribution.Version (VersionRange, anyVersion)
import Network.HTTP.Types (badRequest400, unsupportedMediaType415,
  noContent204)
import Network.Legion (forkLegionary, Runtime, makeRequest,
  search, SearchTag(SearchTag), IndexRecord, irKey, PartitionKey,
  newMemoryPersistence, Persistence)
import Network.Legion.Config (parseArgs)
import Network.Legion.Discovery.App (Input(GetRange, Ping, GetService,
  GetRequests), Output(InstanceList, PingResponse, ServiceResponse),
  ServiceId(ServiceId), toKey, Time(Time), unServiceAddr, version,
  ServiceAddr(ServiceAddr), InstanceInfo, Service, instances, name,
  unServiceId, State, Client(Client), cName, cVersion)
import Network.Legion.Discovery.Config (servicePort)
import Network.Legion.Discovery.Graphviz (toDotGraph)
import Network.Legion.Discovery.HttpError (HttpError)
import Network.Legion.Discovery.LIO (runLIO, LIO)
import Network.Legion.Discovery.UserAgent (UserAgent(UAProduct), Product(
  Product), productName, productVersion, parseUserAgent)
import Network.Wai (Middleware, modifyResponse)
import Network.Wai.Middleware.AddHeaders (addHeaders)
import Network.Wai.Middleware.StripHeaders (stripHeader)
import Web.Scotty.Format.Trans (respondTo, format)
import Web.Scotty.Resource.Trans (resource, get, post)
import Web.Scotty.Trans (scottyT, middleware, ScottyT, param, setHeader,
  raw, status, text, ActionT, ScottyError, body, header)
import qualified Data.ByteString.Lazy as LBS
import qualified Data.Conduit.List as CL
import qualified Data.Map as Map
import qualified Data.Text.Encoding as TE
import qualified Data.Text.Lazy as TL
import qualified Data.Text.Lazy.Encoding as TLE
import qualified Network.Legion.Discovery.Config as C
import qualified Paths_legion_discovery as P

{- |
  Program entry point: parse the command line arguments, start the
  legion runtime on top of in-memory persistence, and then run the
  scotty web service on the configured port.
-}
main :: IO ()
main = do
  (settings, startupMode, config) <- parseArgs
  logging <- getCantevenOutput (C.logging config)
  persist <- newMemoryPersistence :: IO (Persistence Input Output State)
  let
    {- The middleware stack, composed in the same order as it is applied. -}
    middlewareStack =
      requestLogging logging
      . setServer "legion-discovery"
      . logExceptionsAndContinue logging
  runLIO logging $ do
    legion <- forkLegionary persist settings startupMode
    scottyT (servicePort config) (runLIO logging) $ do
      middleware middlewareStack
      webService legion


{- |
  The web service endpoint definitions.

  Every handler talks to the legion runtime passed in as @runtime@,
  either through 'makeRequest' (directly, or via 'getInstances' and
  'getService') or through 'search'.
-}
webService :: Runtime Input Output -> ScottyT HttpError LIO ()
webService runtime = do
  {-
    POST /v1/ping/:serviceId/:version

    Sends a 'Ping' request, stamped with the current time, for the
    service address found in the request entity (see 'PingRequest').
    Responds with 400 when the version parameter does not parse, and
    with 204 once the runtime answers with 'PingResponse'. Entity
    decoding failures are answered by 'withEntity'.
  -}
  resource "/v1/ping/:serviceId/:version" $
    post $
      simpleParse <$> param "version" >>= \case
        Nothing -> do
          status badRequest400
          text "Invalid version."
        Just ver -> do
          serviceId <- ServiceId <$> param "serviceId"
          now <- Time <$> liftIO getCurrentTime
          withEntity (\ PingRequest {serviceAddress} -> do
              let req = Ping now serviceId ver serviceAddress
              makeRequest runtime (toKey serviceId) req >>= \case
                PingResponse -> status noContent204
                {- The remaining constructors are rejected via 'fail'. -}
                InstanceList _ -> fail "Invalid runtime response."
                ServiceResponse _ -> fail "Invalid runtime response."
            )
  {-
    GET /v1/services

    Responds with a JSON object that maps each service name to its
    encoded instances (see 'encodeInstances').
  -}
  resource "/v1/services" $
    get $ do
      now <- Time <$> liftIO getCurrentTime
      {-
        Look up the 'Service' behind every index record produced by the
        search; records whose lookup yields 'Nothing' are dropped by
        'CL.mapMaybeM'.
        NOTE(review): @SearchTag "" Nothing@ presumably matches every
        index record - confirm against Network.Legion.
      -}
      list <- runConduit (
          search runtime (SearchTag "" Nothing)
          =$= CL.mapMaybeM (fillServiceInfo now)
          =$= CL.consume
        )
      respondTo $
        format serviceListCT $ do
          setHeader "content-type" serviceListCT
          raw . encode $ object [
              unServiceId (name service) .= encodeInstances (instances service)
              | service <- list
            ]
  let
    {-
      Build the dot graph (via 'toDotGraph') from the services obtained
      by sending 'GetRequests' to the key of every index record returned
      by the search.
    -}
    getGraph :: (MonadIO m, ScottyError e) => ActionT e m (DotGraph TL.Text)
    getGraph = 
      toDotGraph <$> runConduit (
          search runtime (SearchTag "" Nothing)
          =$= CL.mapMaybeM fillClientInfo
          =$= CL.consume
        )

    {-
      The graph as SVG: runs the graphviz 'Dot' command with 'Svg' output
      and reads the whole result as a strict 'ByteString'.
    -}
    renderSvg graph = format svgCT $ do
      setHeader "content-type" svgCT
      bytes <- liftIO $ graphvizWithHandle Dot graph Svg hGetContents
      raw (LBS.fromStrict bytes)

    {- The graph as unrendered dot source text. -}
    renderGraphviz graph = format graphvizCT $ do
      setHeader "content-type" graphvizCT
      text (renderDot (toDot graph))

  {-
    GET /v1/graph

    Offers the graph in both formats defined above.
    NOTE(review): 'respondTo' presumably picks between them based on the
    request's Accept header - confirm against scotty-format-trans.
  -}
  resource "/v1/graph" $
    get $ do
      graph <- getGraph
      respondTo $ do
        renderSvg graph
        renderGraphviz graph
  let
    {-
      Shared handler for the two instance-listing routes below: parse the
      clients from the User-Agent header (see 'withClients'), send a
      'GetRange' request for @serviceId@ carrying those clients, the
      current time and @range@, and respond with the resulting instances.
    -}
    queryResource :: (MonadIO m, ScottyError e)
      => ServiceId
      -> VersionRange
      -> ActionT e m ()
    queryResource serviceId range = 
      withClients $ \clients -> do
        now <- Time <$> liftIO getCurrentTime
        respondInstances =<<
          getInstances
            runtime
            (toKey serviceId)
            (GetRange clients serviceId now range)
  {- GET /v1/services/:serviceId - instances of a service, any version. -}
  resource "/v1/services/:serviceId" $
    get $ do
      serviceId <- ServiceId <$> param "serviceId"
      queryResource serviceId anyVersion
  {-
    GET /v1/services/:serviceId/:versionRange

    Instances of a service restricted to a version range. Responds with
    400 when the version range parameter does not parse.
  -}
  resource "/v1/services/:serviceId/:versionRange" $
    get $
      simpleParse <$> param "versionRange" >>= \case
        Nothing -> do
          status badRequest400
          text "Invalid version range."
        Just range -> do
          serviceId <- ServiceId <$> param "serviceId"
          queryResource serviceId range
  where
    {- |
      Parse the list of clients from the User-Agent header and return
      the resulting 'ActionT', or else return an 'ActionT' that returns
      a @400 Bad Request@ if the User-Agent header is missing or invalid.
    -}
    withClients :: (Monad m, ScottyError e)
      => ([Client] -> ActionT e m ())
      -> ActionT e m ()
    withClients f =
      fmap parseUserAgent <$> headerBS "user-agent" >>= \case
        Nothing -> do
          status badRequest400
          text "Missing User-Agent."
        Just (Left err) -> do
          status badRequest400
          text $ "Invalid User-Agent: " <> TL.pack err
        Just (Right uas) ->
          let
            {-
              Only the 'UAProduct' elements of the parsed header become
              clients. When the product version parses (via 'simpleParse')
              the client gets that version; otherwise the client has no
              version and the raw version text, if any, is appended to
              the client name after a @/@.
            -}
            clients = [
                Client {
                    cName = ServiceId (decodeUtf8 name),
                    cVersion = version
                  }
                | UAProduct Product {productName, productVersion} <- uas
                , let
                    vstring = unpack . decodeUtf8 <$> productVersion
                    (name, version) = case simpleParse =<< vstring of
                      Nothing -> (
                          productName <> maybe "" ("/" <>) productVersion,
                          Nothing
                        )
                      Just v -> (productName, Just v)
              ]
          in f clients


    {- |
      Get the service info for the service listing, by sending
      'GetService' (carrying the given time) to the key of the index
      record.
    -}
    fillServiceInfo :: (MonadIO m, ScottyError e)
      => Time
      -> IndexRecord
      -> ActionT e m (Maybe Service)
    fillServiceInfo now ir = getService runtime (irKey ir) (GetService now)

    {- |
      Get the client info for a service, by sending 'GetRequests' to the
      key of the index record.
    -}
    fillClientInfo :: (MonadIO m, ScottyError e)
      => IndexRecord
      -> ActionT e m (Maybe Service)
    fillClientInfo ir = getService runtime (irKey ir) GetRequests

    {- |
      Send a response containing the service instance list, encoded by
      'encodeInstances' and labelled with the 'instanceListCT' content
      type.
    -}
    respondInstances :: (Monad m, ScottyError e)
      => Map ServiceAddr InstanceInfo
      -> ActionT e m ()
    respondInstances is = do
      setHeader "content-type" instanceListCT
      (raw . encode . encodeInstances) is

    {- |
      Encode instances into a JSON object. The object has one field per
      service address, whose value is an object with a single @version@
      field holding the instance version rendered by 'showVersion'.
    -}
    encodeInstances :: Map ServiceAddr InstanceInfo -> Value
    encodeInstances instances = object [
        unServiceAddr addr .= object [
            "version" .= showVersion (version info)
          ]
        | (addr, info) <- Map.toList instances
      ]

{- |
  Send a legion request whose only acceptable answer is an
  'InstanceList', and return the instances it carries. Any other
  response constructor is reported through 'fail'.
-}
getInstances :: (MonadIO io)
  => Runtime Input Output
  -> PartitionKey
  -> Input
  -> io (Map ServiceAddr InstanceInfo)
getInstances runtime key input =
    makeRequest runtime key input >>= expectInstanceList
  where
    {- Unwrap the one response constructor we accept. -}
    expectInstanceList (InstanceList found) = return found
    expectInstanceList PingResponse = fail "Invalid runtime response."
    expectInstanceList (ServiceResponse _) = fail "Invalid runtime response."


{- |
  Send a legion request whose only acceptable answer is a
  'ServiceResponse', and return the (optional) service it carries. Any
  other response constructor is reported through 'fail'.
-}
getService :: (MonadIO io)
  => Runtime Input Output
  -> PartitionKey
  -> Input
  -> io (Maybe Service)
getService runtime key input =
    makeRequest runtime key input >>= expectServiceResponse
  where
    {- Unwrap the one response constructor we accept. -}
    expectServiceResponse (ServiceResponse found) = return found
    expectServiceResponse PingResponse = fail "Invalid runtime response."
    expectServiceResponse (InstanceList _) = fail "Invalid runtime response."


{- |
  Set the server header: strip any @Server@ header from the response of
  the wrapped application, then add our own, whose value is the given
  service name followed by @/@ and the version of this package.
-}
setServer :: String -> Middleware
setServer serviceName =
    addHeaders [("Server", ourServer)] . modifyResponse (stripHeader "Server")
  where
    {- The value of the @Server:@ header. -}
    ourServer =
      TE.encodeUtf8 . pack $ serviceName ++ "/" ++ showVersion P.version


{- |
  The service instance list content type. It is polymorphic in its
  string type ('IsString'), so the same literal can be used wherever a
  string-like value is expected.
-}
instanceListCT :: (IsString a) => a
instanceListCT = "application/vnd.legion-discovery.instance-list+json"


{- |
  The known service list content type. It is polymorphic in its string
  type ('IsString'), so the same literal can be used wherever a
  string-like value is expected.
-}
serviceListCT :: (IsString a) => a
serviceListCT = "application/vnd.legion-discovery.service-list+json"


{- |
  The content type of unrendered dot graphs. It is polymorphic in its
  string type ('IsString'), so the same literal can be used wherever a
  string-like value is expected.
-}
graphvizCT :: (IsString a) => a
graphvizCT = "text/vnd.graphviz"


{- |
  The content type of svg data. It is polymorphic in its string type
  ('IsString'), so the same literal can be used wherever a string-like
  value is expected.
-}
svgCT :: (IsString a) => a
svgCT = "image/svg+xml"


{- |
  Scotty shorthand for getting and decoding an entity.

  The @content-type@ header and the request body are handed to
  'decodeEntity'. A successfully decoded value is passed to the given
  handler; an unsupported entity is answered with
  @415 Unsupported Media Type@, and an entity that fails to decode is
  answered with @400 Bad Request@ and the decoding error as the body.
-}
withEntity :: (FromEntity a, MonadIO m, ScottyError e)
  => (a -> ActionT e m ())
  -> ActionT e m ()
withEntity f = do
  contentType <- headerLBS "content-type"
  entity <- body
  case decodeEntity contentType entity of
    Ok decoded -> f decoded
    Unsupported -> status unsupportedMediaType415
    BadEntity why -> do
      status badRequest400
      text (TL.pack why)


{- |
  Get a header as a 'Data.ByteString.Lazy.ByteString', by UTF-8 encoding
  the lazy text returned by 'header'. Yields 'Nothing' when 'header'
  does.
-}
headerLBS :: (ScottyError e, Monad m)
  => TL.Text
  -> ActionT e m (Maybe LBS.ByteString)
headerLBS = fmap (fmap TLE.encodeUtf8) . header


{- |
  The entity of a ping request: a JSON object whose @serviceAddress@
  field carries the address being pinged.

  As an entity it is only accepted when the content type is exactly
  'PingRequestCT'; anything else is 'Unsupported'.
-}
newtype PingRequest = PingRequest {
    serviceAddress :: ServiceAddr
  }
instance FromJSON PingRequest where
  parseJSON = \case
    Object fields -> do
      address <- fields .: "serviceAddress"
      return (PingRequest (ServiceAddr address))
    other ->
      fail ("Can't parse PingRequest from: " ++ show other)
instance FromEntity PingRequest where
  decodeEntity contentType bytes =
    case contentType of
      Just PingRequestCT -> either BadEntity Ok (eitherDecode bytes)
      _ -> Unsupported


{- |
  The content type for ping requests. It is a pattern synonym so that it
  can be matched on directly, as the 'FromEntity' instance of
  'PingRequest' does.
-}
pattern PingRequestCT :: (IsString a, Eq a) => a
pattern PingRequestCT = "application/vnd.legion-discovery.ping-request+json"


{- |
  Like 'Web.Scotty.Trans.header', except that the value is returned as a
  strict 'ByteString' (the UTF-8 encoding of the header text) instead
  of 'TL.Text'.
-}
headerBS :: (Monad m, ScottyError e)
  => TL.Text
  -> ActionT e m (Maybe ByteString)
headerBS headerName = do
  value <- header headerName
  return (encodeUtf8 . TL.toStrict <$> value)