{-# LANGUAGE AllowAmbiguousTypes  #-}
{-# LANGUAGE DataKinds            #-}
{-# LANGUAGE DerivingStrategies   #-}
{-# LANGUAGE LambdaCase           #-}
{-# LANGUAGE OverloadedStrings    #-}
{-# LANGUAGE TypeFamilies         #-}
{-# LANGUAGE UndecidableInstances #-}
{-|
Module      : Servant.Lint
Description : Lint Servant API types for common problems
Copyright   : (c) Isaac Shapira, 2024
License     : BSD-3-Clause
Maintainer  : isaac.shapira@platonic.systems
Stability   : experimental

This module provides linting functionality for Servant API types to detect:

* Ambiguous route overlaps
* Duplicate type acceptance (same type in multiple combinators)
* Duplicate QueryParam names
* Invalid HTTP method/response code combinations
* Incorrect ReqBody placement

Use 'lintAPI' in your test suite to catch these issues early:

@
import Servant.Lint

type MyAPI = "users" :> Capture "id" Int :> Get '[JSON] User

main :: IO ()
main = lintAPI @MyAPI
@

For more control over error handling, use 'lintAPI'' to get the errors as a list,
or 'printLintAPI' to print them to stdout with colors.
-}
module Servant.Lint
  ( lintAPI
  , lintAPI'
  , printLintAPI
  , Error(..)
  , Path(..)
  , Lintable(..)
  , Ambiguity(..)
  , unlinesChunks
  ) where

import           Data.ByteString           (ByteString)
import           Data.Containers.ListUtils (nubOrd)
import           Data.Kind                 (Constraint, Type)
import           Data.List                 ((\\))
import           Data.Maybe                (mapMaybe, maybeToList)
import           Data.Proxy                (Proxy (Proxy))
import qualified Data.Text                 as Text (Text, intercalate, pack, unpack)
import           Data.Text.Encoding        (decodeUtf8)
import           Data.Typeable             (TypeRep, Typeable, typeRep)
import           GHC.Generics              (Generic)
import           GHC.TypeLits              (KnownNat, KnownSymbol, Symbol, natVal,
                                            symbolVal)
import           Servant.API               (AuthProtect, Capture', CaptureAll, Description,
                                            NoContent, NoContentVerb,
                                            QueryParam, QueryParams,
                                            ReflectMethod (..), ReqBody',
                                            Summary, Verb, type (:<|>),
                                            type (:>))
import           Servant.Multipart         (MultipartForm)
import           Text.Colour               (Chunk, TerminalCapabilities (..),
                                            bold, renderChunksText)
import           Text.Colour.Chunk         (chunk, fore, red)

-- | A term-level representation of a single route through the API.
-- It is defined recursively (each combinator wraps the rest of the route) so
-- that the API can be flattened into a list of routes instead of a tree.
--
-- 'Ord' is derived, so the constructor order below determines how 'Path'
-- values compare.
type Path :: Type
data Path
  = PPath String !Path
    -- ^ A static path segment (the 'Symbol' in @"segment" :> rest@).
  | PCapture String TypeRep !Path
    -- ^ @Capture@: the capture name and the captured type.
  | PCaptureAll String TypeRep !Path
    -- ^ @CaptureAll@: the capture name and the type argument of the combinator.
  | PQueryParam String TypeRep !Path
    -- ^ @QueryParam@: the parameter name and its type.
  | PQueryParams String TypeRep !Path
    -- ^ @QueryParams@: the parameter name and the type argument of the combinator.
  | PReqBody TypeRep !Path
    -- ^ @ReqBody@: the request body type.
  | PVerb ByteString Integer TypeRep
    -- ^ The terminal verb: HTTP method, status code and response type.
  deriving (Eq, Ord)

-- | Renders a route in a Servant-like surface syntax, one combinator per
-- @:>@, ending with the verb.
instance Show Path where
  show = \case
      PPath s rest          -> "\"" <> s <> "\" :> " <> show rest
      PCapture s t rest     -> labelled "Capture" s t rest
      PCaptureAll s t rest  -> labelled "CaptureAll" s t rest
      PQueryParam s t rest  -> labelled "QueryParam" s t rest
      PQueryParams s t rest -> labelled "QueryParams" s t rest
      PReqBody t rest       -> "ReqBody _ _ " <> show t <> " :> " <> show rest
      PVerb method code t   ->
        unwords ["Verb '" <> Text.unpack (decodeUtf8 method), show code, show t]
    where
      -- A named, typed combinator followed by the rest of the route.
      labelled :: String -> String -> TypeRep -> Path -> String
      labelled label name t rest =
        label <> " \"" <> name <> "\" " <> show t <> " :> " <> show rest

-- | The Lintable type class describes how to go from a Servant combinator
-- (or a whole API type) to the list of routes ('Path's) it contributes.
-- This is essentially a type-level function @Type -> [Path]@.
-- If you have custom Servant combinators you may need to add a 'Lintable'
-- instance for your combinator, typically one that ignores the combinator
-- and returns the paths of the rest of the API.
type Lintable :: Type -> Constraint
class Lintable a where
  -- | Every route described by @a@, flattened to a list.
  paths :: [Path]

instance (Lintable a, Lintable b) => Lintable (a :<|> b) where
  -- Alternatives contribute the routes of both sides, left side first.
  paths = concat [paths @a, paths @b]

instance (Lintable b, KnownSymbol hint, Typeable a) => Lintable (Capture' _mods hint a :> b) where
  -- Prefix every downstream route with this capture's name and type.
  paths = map (PCapture captureName captureType) (paths @b)
    where
      captureName = symbolVal (Proxy @hint)
      captureType = typeRep (Proxy @a)

instance (Lintable b, KnownSymbol hint, Typeable a) => Lintable (CaptureAll hint a :> b) where
  -- Prefix every downstream route with this capture-all's name and type.
  paths = [PCaptureAll name rep rest | rest <- paths @b]
    where
      name = symbolVal (Proxy @hint)
      rep  = typeRep (Proxy @a)

instance (Lintable b, KnownSymbol hint, Typeable a) => Lintable (QueryParam hint a :> b) where
  -- Prefix every downstream route with this query parameter's name and type.
  paths = fmap (PQueryParam paramName paramType) (paths @b)
    where
      paramName = symbolVal (Proxy @hint)
      paramType = typeRep (Proxy @a)

-- | @QueryParams@ is recorded with its own 'PQueryParams' constructor (not
-- 'PQueryParam'), so routes are displayed with the correct combinator name.
instance (Lintable b, KnownSymbol hint, Typeable a) => Lintable (QueryParams hint a :> b) where
  paths = PQueryParams (symbolVal (Proxy @hint)) (typeRep (Proxy @a)) <$> paths @b

instance (Lintable b, Typeable a) => Lintable (ReqBody' _mods _ms a :> b) where
  -- Prefix every downstream route with the request body's type.
  paths = map (PReqBody bodyType) (paths @b)
    where
      bodyType = typeRep (Proxy @a)

instance (KnownSymbol a, Lintable b) => Lintable (a :> b) where
  -- A type-level string is a static path segment in front of the rest.
  paths = [PPath segment rest | rest <- paths @b]
    where
      segment = symbolVal (Proxy @a)

-- | @Summary@ is skipped: it adds nothing to the linted 'Path'.
instance Lintable b => Lintable (Summary _a :> b) where
  paths = paths @b

-- | @Description@ is skipped: it adds nothing to the linted 'Path'.
instance Lintable b => Lintable (Description _a :> b) where
  paths = paths @b

-- | @MultipartForm@ is skipped. Unlike @ReqBody@ it is not recorded in the
-- 'Path', so body-related checks do not see it.
instance (Lintable b) => Lintable (MultipartForm tag a :> b) where
  paths = paths @b

-- | @AuthProtect@ is skipped: it adds nothing to the linted 'Path'.
instance (Lintable b) => Lintable (AuthProtect (a :: Symbol) :> b) where
  paths = paths @b

instance {-# OVERLAPPABLE #-}
  ( ReflectMethod method
  ) => Lintable (NoContentVerb method) where
  -- Recorded as a single verb with status 204 and 'NoContent' as its type.
  paths = [PVerb verbMethod 204 noContentRep]
    where
      verbMethod   = reflectMethod (Proxy @method)
      noContentRep = typeRep (Proxy @NoContent)

instance {-# OVERLAPPABLE #-}
  ( Typeable ret
  , ReflectMethod method
  , KnownNat code
  ) => Lintable (Verb method code _cs ret) where
  -- A verb ends a route: record its method, status code and response type.
  paths = [PVerb verbMethod statusCode responseRep]
    where
      verbMethod  = reflectMethod (Proxy @method)
      statusCode  = natVal (Proxy @code)
      responseRep = typeRep (Proxy @ret)

-- | A stripped-down version of 'Path' that keeps only the details used to
-- decide whether two routes overlap. Capture names and all types are
-- removed (see 'ambiguity').
type Ambiguity :: Type
data Ambiguity
  = ACapture
    -- ^ A @Capture@ (name and type dropped).
  | ACaptureAll
    -- ^ A @CaptureAll@ (name and type dropped).
  | AQueryParam String
    -- ^ A @QueryParam@ or @QueryParams@, identified only by its name.
  | APath String
    -- ^ A static path segment.
  | AReqBody
    -- ^ A @ReqBody@ (type dropped).
  | AVerb ByteString Integer
    -- ^ The verb's HTTP method and status code (response type dropped).
  deriving (Eq, Ord, Show, Generic)

-- | Loose, deliberately non-lawful equality used for ambiguity detection.
-- 'ACapture' matches anything on either side. 'ACaptureAll' never matches
-- here, because '(=!!=)' deals with it before reaching this function. Every
-- other pair matches exactly when the two values are equal.
(=!=) :: Ambiguity -> Ambiguity -> Bool
lhs =!= rhs = case (lhs, rhs) of
  (ACapture, _)    -> True
  (_, ACapture)    -> True
  (ACaptureAll, _) -> False
  _                -> lhs == rhs

-- | Loose, deliberately non-lawful equality on route shapes (see 'ambiguity').
--
-- * If either side starts with 'ACaptureAll' and the other side is non-empty,
--   only the last elements are compared with '(=!=)'. For shapes built by
--   'ambiguity' those are the 'AVerb's.
-- * Otherwise the lists must have the same length and match element-wise
--   under '(=!=)'.
(=!!=) :: [Ambiguity] -> [Ambiguity] -> Bool
-- 'last' is safe in the next two clauses: both sides are matched as non-empty.
a@(ACaptureAll : _) =!!= b@(_:_) = last a =!= last b
a@(_:_) =!!= b@(ACaptureAll:_)   = last a =!= last b
(a:as) =!!= (b:bs)               = a =!= b && as =!!= bs
[] =!!= []                       = True
-- Different lengths (or a CaptureAll against an empty shape) never match.
_ =!!= _                         = False

-- | Reduce a route to the shape that matters for overlap detection.
-- Capture names and all types are dropped. Path segments, query-parameter
-- names and the verb's method and status code are kept.
ambiguity :: Path -> [Ambiguity]
ambiguity path = case path of
  PVerb method code _     -> [AVerb method code]
  PPath segment rest      -> APath segment : ambiguity rest
  PCapture _ _ rest       -> ACapture : ambiguity rest
  PCaptureAll _ _ rest    -> ACaptureAll : ambiguity rest
  PReqBody _ rest         -> AReqBody : ambiguity rest
  PQueryParam name _ rest -> AQueryParam name : ambiguity rest
  PQueryParams name _ rest -> AQueryParam name : ambiguity rest

-- | Collect, in route order, the type of every Capture, CaptureAll, ReqBody,
-- QueryParam and QueryParams on a route. The verb's response type is not
-- included. Used for duplicate-type detection.
pathTypes :: Path -> [TypeRep]
pathTypes (PPath _ rest)           = pathTypes rest
pathTypes (PCapture _ ty rest)     = ty : pathTypes rest
pathTypes (PCaptureAll _ ty rest)  = ty : pathTypes rest
pathTypes (PReqBody ty rest)       = ty : pathTypes rest
pathTypes (PQueryParam _ ty rest)  = ty : pathTypes rest
pathTypes (PQueryParams _ ty rest) = ty : pathTypes rest
pathTypes PVerb{}                  = []

-- | Render a route as text, appending a 👈 pointer to every typed segment
-- (Capture, CaptureAll, ReqBody, QueryParam, QueryParams) whose type is in
-- the given list of duplicated types.
showRouteWithDuplicateHighlights :: [TypeRep] -> Path -> Text.Text
showRouteWithDuplicateHighlights dupTypes = render
  where
    -- Mark a segment whose type is one of the duplicates.
    flag :: TypeRep -> Text.Text -> Text.Text
    flag ty txt
      | ty `elem` dupTypes = txt <> " 👈"
      | otherwise          = txt

    -- @Label "name" Type@
    named :: Text.Text -> String -> TypeRep -> Text.Text
    named label name ty = label <> " \"" <> Text.pack name <> "\" " <> Text.pack (show ty)

    -- A segment followed by the rendering of the rest of the route.
    continue :: Text.Text -> Path -> Text.Text
    continue segment rest = segment <> " :> " <> render rest

    render :: Path -> Text.Text
    render = \case
      PPath s rest           -> "\"" <> Text.pack s <> "\" :> " <> render rest
      PCapture s ty rest     -> continue (flag ty (named "Capture" s ty)) rest
      PCaptureAll s ty rest  -> continue (flag ty (named "CaptureAll" s ty)) rest
      PReqBody ty rest       -> continue (flag ty ("ReqBody _ _ " <> Text.pack (show ty))) rest
      PQueryParam s ty rest  -> continue (flag ty (named "QueryParam" s ty)) rest
      PQueryParams s ty rest -> continue (flag ty (named "QueryParams" s ty)) rest
      PVerb method code ty   ->
        "Verb '" <> decodeUtf8 method <> " " <> Text.pack (show code) <> " " <> Text.pack (show ty)

-- | Collect, in route order, the name of every QueryParam and QueryParams on
-- a route. Used for duplicate-name detection.
pathQueryParamNames :: Path -> [String]
pathQueryParamNames path = case path of
  PQueryParam name _ rest  -> name : pathQueryParamNames rest
  PQueryParams name _ rest -> name : pathQueryParamNames rest
  PPath _ rest             -> pathQueryParamNames rest
  PCapture _ _ rest        -> pathQueryParamNames rest
  PCaptureAll _ _ rest     -> pathQueryParamNames rest
  PReqBody _ rest          -> pathQueryParamNames rest
  PVerb{}                  -> []

-- | Render a route as text, appending a 👈 pointer to every QueryParam or
-- QueryParams segment whose name is in the given list of duplicated names.
showRouteWithDuplicateQueryParamHighlights :: [String] -> Path -> Text.Text
showRouteWithDuplicateQueryParamHighlights dupNames = render
  where
    -- @Label "name" Type@
    typed :: Text.Text -> String -> TypeRep -> Text.Text
    typed label name ty = label <> " \"" <> Text.pack name <> "\" " <> Text.pack (show ty)

    -- Mark a query segment whose name is one of the duplicates.
    flagged :: String -> Text.Text -> Text.Text
    flagged name txt
      | name `elem` dupNames = txt <> " 👈"
      | otherwise            = txt

    render :: Path -> Text.Text
    render = \case
      PPath s rest              -> "\"" <> Text.pack s <> "\" :> " <> render rest
      PCapture s ty rest        -> typed "Capture" s ty <> " :> " <> render rest
      PCaptureAll s ty rest     -> typed "CaptureAll" s ty <> " :> " <> render rest
      PReqBody ty rest          -> "ReqBody _ _ " <> Text.pack (show ty) <> " :> " <> render rest
      PQueryParam name ty rest  -> flagged name (typed "QueryParam" name ty) <> " :> " <> render rest
      PQueryParams name ty rest -> flagged name (typed "QueryParams" name ty) <> " :> " <> render rest
      PVerb method code ty      ->
        "Verb '" <> decodeUtf8 method <> " " <> Text.pack (show code) <> " " <> Text.pack (show ty)

-- | Report a route that uses the same QueryParam/QueryParams name more than
-- once. Each repeated name is listed once in the message, however many times
-- it occurs, and is highlighted in the rendered route.
-- Returns 'Nothing' when all names on the route are distinct.
checkForDuplicateQueryParamNames :: Path -> Maybe Error
checkForDuplicateQueryParamNames p =
  case duplicatedNames of
    [] -> Nothing
    dups -> Just $ Error $
      [ chunk "Route has multiple QueryParam with the same name: "
      , bold $ chunk $ Text.intercalate ", " (Text.pack <$> dups)
      , chunk ". QueryParam names must be unique within a route:"
      ] : [[chunk $ "\t" <> showRouteWithDuplicateQueryParamHighlights dups p]]
  where
  names = pathQueryParamNames p
  -- @names \\ nubOrd names@ keeps every occurrence after the first, so a name
  -- used three times would be listed twice; the outer 'nubOrd' lists it once.
  duplicatedNames = nubOrd (names \\ nubOrd names)

-- | Report a route that accepts the same type more than once across its
-- Capture, CaptureAll, ReqBody, QueryParam and QueryParams combinators.
-- Each repeated type is listed once in the message, however many times it
-- occurs, and is highlighted in the rendered route.
-- Returns 'Nothing' when all accepted types are distinct.
checkForDuplicates :: Path -> Maybe Error
checkForDuplicates p =
  case duplicatedTypes of
    [] -> Nothing
    dups -> Just $ Error $
      [ chunk "Route accepts the same type multiple times: "
      , bold $ chunk $ Text.intercalate ", " (Text.pack . show <$> dups)
      , chunk ". This doesn't guarantee argument order and can lead to ambiguous behavior:"
      ] : [[chunk $ "\t" <> showRouteWithDuplicateHighlights dups p]]
  where
  types = pathTypes p
  -- @types \\ nubOrd types@ keeps every occurrence after the first, so a type
  -- used three times would be listed twice; the outer 'nubOrd' lists it once.
  duplicatedTypes = nubOrd (types \\ nubOrd types)

-- | A lint failure, kept as coloured terminal output via @Text.Colour@.
-- 'lintAPI' and 'printLintAPI' render each inner list of chunks as one line.
newtype Error = Error
  { toChunks :: [[Chunk]]
  }
  deriving newtype (Eq, Show, Semigroup, Monoid)

-- | Does the route shape match any of the given shapes under '(=!!=)'?
elem' :: [Ambiguity] -> [[Ambiguity]] -> Bool
elem' needle = any (needle =!!=)

-- | Pass your API type for lint errors as Chunks.
-- Chunks are coloured terminal bits from @Text.Colour@ for making
-- pretty errors. An empty list means no problems were found.
--
-- Every route is checked for:
--
-- * ambiguity with a /later/ route, so an ambiguous pair is reported once
--   (from its first member) rather than once from each side;
-- * a bad verb, status code or @ReqBody@ placement ('badReturn');
-- * the same type being accepted more than once;
-- * repeated @QueryParam@ names.
lintAPI' :: forall api. Lintable api => [Error]
lintAPI' = go psAll
  where
    psAll = paths @api

    go :: [Path] -> [Error]
    go [] = []
    go (p : later) =
      maybeToList (ambiguousWithLater p later)
        <> mapMaybe ($ p) routeChecks
        <> go later

    -- Checks that only look at a single route. 'badReturn' takes the route
    -- twice: once as the route to highlight, once as the route to inspect.
    routeChecks :: [Path -> Maybe Error]
    routeChecks =
      [ \route -> badReturn psAll route route
      , checkForDuplicates
      , checkForDuplicateQueryParamNames
      ]

    -- Only later routes are compared, so each ambiguous pair yields one error.
    ambiguousWithLater :: Path -> [Path] -> Maybe Error
    ambiguousWithLater route rest
      | ambiguity route `elem'` (ambiguity <$> rest) = Just (printAmbiguity psAll route)
      | otherwise                                    = Nothing

-- | Pass your API type for lint errors thrown in IO (via 'error').
-- This is typically useful for testing.
lintAPI :: forall api. Lintable api => IO ()
lintAPI = case lintAPI' @api of
    [] -> pure ()
    errs -> error $ Text.unpack $ renderErrors errs

-- | Pass your API type for lint via @putStrLn@ in stdout
printLintAPI :: forall api. Lintable api => IO ()
printLintAPI = case lintAPI' @api of
    [] -> pure ()
    errs -> putStrLn $ Text.unpack $ renderErrors errs

-- | Render errors with 24-bit colour. Every line of every error is
-- newline-terminated, and each error is followed by one extra newline.
renderErrors :: [Error] -> Text.Text
renderErrors =
  renderChunksText With24BitColours . unlinesChunks . fmap (unlinesChunks . toChunks)

-- | Check one route for bad verb/status-code combinations and a misplaced
-- @ReqBody@.
--
-- @badReturn psAll c p@ walks down route @p@. @c@ is the route to highlight
-- (with 👈) in the error, which lists every route in @psAll@. Callers
-- normally pass the same route for @c@ and @p@.
--
-- Flags:
--
-- * a verb returning status 500;
-- * a 'NoContent' response with any status other than 204;
-- * @ReqBody@ in front of a @GET@ verb;
-- * @ReqBody@ that is not immediately followed by the verb.
badReturn :: [Path] -> Path -> Path -> Maybe Error
badReturn psAll c = \case
  PPath _ p -> badReturn psAll c p
  PCapture _ _ p -> badReturn psAll c p
  PCaptureAll _ _ p -> badReturn psAll c p
  PQueryParam _ _ p -> badReturn psAll c p
  -- Walk through QueryParams too, otherwise the verb of such a route is
  -- never checked.
  PQueryParams _ _ p -> badReturn psAll c p
  PVerb _method 500 _ty -> Just $ Error $
     [ chunk "Bad verb, you should never intentionally return 500 as part of your API:"
     ] : (badReturnColor <$> psAll)
  PVerb _method code ty | code /= 204 && ty == typeRep (Proxy @NoContent) -> Just $ Error $
     [ chunk "Bad verb, NoContent must use HTTP Status Code 204, not "
     , bold $ chunk $ Text.pack $ show code <> ":"
     ] : (badReturnColor <$> psAll)
  PReqBody _ (PVerb "GET" _ _) -> Just $ Error $
     [ chunk "Bad verb, do not use ReqBody in a GET request, Http 1.1 says its meaningless"
     ] : (badReturnColor <$> psAll)
  -- A correctly placed ReqBody: still check the verb it precedes, so a
  -- 500 or a non-204 NoContent behind a body is not missed.
  PReqBody _ verb@PVerb{} -> badReturn psAll c verb
  PReqBody _ _ -> Just $ Error $
     [ chunk "ReqBody must be the last combinator before the Verb"
     ] : (badReturnColor <$> psAll)
  PVerb{} -> Nothing
  where
    -- Render one route of the listing, marking the highlighted route in red.
    badReturnColor :: Path -> [Chunk]
    badReturnColor p' = pure $
      if c == p'
      then fore red $ chunk $ ("\t" <>) $ Text.pack $ show p' <> " 👈"
      else chunk $ ("\t" <>) $ Text.pack $ show p'

-- | Build the error for a route @p@ that overlaps another route. Every route
-- in @ps@ is listed, and those whose shape matches @p@ under '(=!!=)' are
-- marked in red with a 👈.
printAmbiguity :: [Path] -> Path -> Error
printAmbiguity ps p = Error (header : fmap render ps)
  where
    header = [chunk "Ambiguous with ", bold (chunk (Text.pack (show p) <> ":"))]
    target = ambiguity p
    render other
      | target =!!= ambiguity other = [fore red (chunk ("\t" <> Text.pack (show other <> " 👈")))]
      | otherwise                   = [chunk ("\t" <> Text.pack (show other))]

-- | Join lines of chunks into one chunk list, terminating every line
-- (including the last) with a newline chunk.
-- Exported for testing only.
unlinesChunks :: [[Chunk]] -> [Chunk]
unlinesChunks = foldr (\line rest -> line <> (chunk "\n" : rest)) []
