{-# LANGUAGE FlexibleContexts  #-}
{-# LANGUAGE LambdaCase        #-}
{-# LANGUAGE NamedFieldPuns    #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RecordWildCards   #-}
{-# LANGUAGE TupleSections     #-}
{-# LANGUAGE ViewPatterns      #-}
{-|
Module      : Game.Chess.UCI
Description : Universal Chess Interface
Copyright   : (c) Mario Lang, 2021
License     : BSD3
Maintainer  : mlang@blind.guru
Stability   : experimental

The Universal Chess Interface (UCI) is a protocol for communicating with
external Chess engines.
-}
module Game.Chess.UCI (
  -- * Exceptions
  UCIException(..)
  -- * The Engine data type
, Engine, BestMove, name, author
  -- * Starting a UCI engine
, start, start'
  -- * Engine options
, Option(..), options, getOption, setOptionSpinButton, setOptionString
  -- * Manipulating the current game information
, isready
, currentPosition, setPosition, addPly, replacePly
  -- * The Info data type
, Info(..), Score(..), Bounds(..)
  -- * Searching
, search, searching
, SearchParam
, searchmoves, ponder, timeleft, timeincrement, movestogo, movetime, nodes, depth, infinite
, ponderhit
, stop
  -- * Quitting
, quit, quit'
) where

import           Control.Applicative              (Alternative (many, (<|>)),
                                                   optional)
import           Control.Concurrent               (MVar, ThreadId, forkIO,
                                                   killThread, newEmptyMVar,
                                                   putMVar, takeMVar)
import           Control.Concurrent.STM           (TChan, atomically, dupTChan,
                                                   newBroadcastTChanIO,
                                                   writeTChan)
import           Control.Exception                (Exception, handle,
                                                   onException, throwIO)
import           Control.Monad                    (forM, forever, void)
import           Control.Monad.IO.Class           (MonadIO (..))
import           Data.Attoparsec.ByteString.Char8 (Parser, anyChar, choice,
                                                   decimal, endOfInput,
                                                   manyTill, match, parseOnly,
                                                   satisfy, sepBy, sepBy1,
                                                   signed, skipSpace,
                                                   takeByteString)
import           Data.ByteString.Builder          (Builder, byteString,
                                                   hPutBuilder, intDec,
                                                   integerDec)
import           Data.ByteString.Char8            (ByteString)
import qualified Data.ByteString.Char8            as BS
import           Data.Foldable                    (Foldable (fold, foldl', toList))
import           Data.Functor                     (($>))
import           Data.HashMap.Strict              (HashMap)
import qualified Data.HashMap.Strict              as HashMap
import           Data.IORef                       (IORef, atomicModifyIORef',
                                                   newIORef, readIORef,
                                                   writeIORef)
import           Data.Ix                          (Ix (inRange))
import           Data.List                        (intersperse)
import           Data.STRef                       (modifySTRef, newSTRef,
                                                   readSTRef, writeSTRef)
import           Data.Sequence                    (Seq, ViewR ((:>)), (|>))
import qualified Data.Sequence                    as Seq
import           Data.String                      (IsString (..))
import qualified Data.Vector.Unboxed              as Unboxed
import qualified Data.Vector.Unboxed.Mutable      as Unboxed
import           Game.Chess                       (Color (..), Ply, Position,
                                                   doPly, fromUCI, legalPlies,
                                                   startpos, toFEN, toUCI,
                                                   unsafeDoPly)
import           Numeric.Natural                  (Natural)
import           System.Exit                      (ExitCode)
import           System.IO                        (BufferMode (LineBuffering),
                                                   Handle, hSetBuffering)
import           System.Process                   (CreateProcess (std_in, std_out),
                                                   ProcessHandle,
                                                   StdStream (CreatePipe),
                                                   createProcess,
                                                   getProcessExitCode, proc,
                                                   terminateProcess,
                                                   waitForProcess)
import           Time.Rational                    (KnownDivRat)
import           Time.Units                       (Microsecond, Millisecond,
                                                   Time (unTime), ms, sec,
                                                   timeout, toUnit)

-- | Result of a search as reported by a @bestmove@ line.
--
-- 'Nothing' when the engine answered @bestmove (none)@; otherwise the best
-- ply together with the ply the engine suggests pondering on, if it sent one.
type BestMove = Maybe (Ply, Maybe Ply)

-- | A running UCI engine.
--
-- Holds the pipes to the external process, the data reported during the
-- @uci@ handshake (name, author, options), and the state shared with the
-- background thread that reads the engine's output.
data Engine = Engine {
  inH          :: Handle
  -- ^ pipe to the engine's standard input; commands are written here
, outH         :: Handle
  -- ^ pipe from the engine's standard output; read line by line
, procH        :: ProcessHandle
  -- ^ the engine process; checked for exit after every command
, outputStrLn  :: String -> IO ()
  -- ^ sink for engine output lines that could not be parsed
, infoThread   :: Maybe ThreadId
  -- ^ the thread reading engine output after initialisation, if started
, name         :: Maybe ByteString
  -- ^ value of the engine's @id name@ line, if any
, author       :: Maybe ByteString
  -- ^ value of the engine's @id author@ line, if any
, options      :: HashMap ByteString Option
  -- ^ options advertised via @option@ lines, keyed by option name
, isReady      :: MVar ()
  -- ^ filled by the reader thread on @readyok@, emptied by 'isready'
, isSearching  :: IORef Bool
  -- ^ set by 'search', cleared when a @bestmove@ line arrives
, infoChan     :: TChan [Info]
  -- ^ broadcast channel for parsed @info@ lines
, bestMoveChan :: TChan BestMove
  -- ^ broadcast channel for parsed @bestmove@ lines
, game         :: IORef (Position, Seq Ply)
  -- ^ starting position and the plies played from it
}

-- | Replace the current game with the given starting position and plies,
-- then send the new position to the engine.
--
-- The plies are not checked for legality.
setPosition :: (Foldable f, MonadIO m)
            => Engine -> Position -> f Ply
            -> m ()
setPosition e@Engine{game} p pl = liftIO $ do
  let history = Seq.fromList (toList pl)
  atomicModifyIORef' game (const ((p, history), ()))
  sendPosition e

-- | Exceptions thrown by the game-manipulation functions of this module.
--
-- 'IllegalMove' carries a ply that was rejected by 'addPly' because it is
-- not legal in the current position of the game.
data UCIException = IllegalMove Ply deriving Show

instance Exception UCIException

-- | One parsed line of engine output (see 'command').
data Command = Name !ByteString
             -- ^ @id name ...@
             | Author !ByteString
             -- ^ @id author ...@
             | Option !ByteString !Option
             -- ^ @option name ... type ...@: option name and its description
             | UCIOk
             -- ^ @uciok@, the end of the initial handshake
             | ReadyOK
             -- ^ @readyok@, the answer to @isready@
             | Info [Info]
             -- ^ @info ...@ with all of its items
             | BestMove !BestMove
             -- ^ @bestmove ...@
             deriving (Show)

-- | One item of an @info@ line sent by the engine while searching.
data Info = PV !(Unboxed.Vector Ply)
          -- ^ principal variation (@pv@), resolved ply by ply from the current position
          | Depth !Int
          -- ^ search depth (@depth@)
          | SelDepth !Int
          -- ^ selective search depth (@seldepth@)
          | Elapsed !(Time Millisecond)
          -- ^ search time so far (@time@, reported in milliseconds)
          | MultiPV !Int
          -- ^ index of this line in multi-PV mode (@multipv@)
          | Score !Score (Maybe Bounds)
          -- ^ evaluation (@score@), optionally flagged as an upper or lower bound
          | Nodes !Int
          -- ^ nodes searched (@nodes@)
          | NPS !Int
          -- ^ nodes per second (@nps@)
          | TBHits !Int
          -- ^ tablebase hits (@tbhits@)
          | HashFull !Int
          -- ^ hash table fill level (@hashfull@)
          | CurrMove !Ply
          -- ^ ply currently being searched (@currmove@)
          | CurrMoveNumber !Int
          -- ^ number of the ply currently being searched (@currmovenumber@)
          | String !ByteString
          -- ^ free-form text (@string@), up to the end of the line
          deriving (Eq, Show)

-- | An engine evaluation, as reported by @info score@.
--
-- Note that the derived 'Ord' instance orders every 'CentiPawns' value below
-- every 'MateIn' value, regardless of sign, so it is not an evaluation order.
data Score = CentiPawns Int
           -- ^ @score cp <n>@
           | MateIn Int
           -- ^ @score mate <n>@
           deriving (Eq, Ord, Show)

-- | Marks a 'Score' as only a bound (@upperbound@ / @lowerbound@ after the score).
data Bounds = UpperBound | LowerBound deriving (Eq, Show)


-- | An engine option as advertised by an @option@ line, with its current value.
data Option = CheckBox Bool
            -- ^ @type check@: a boolean
            | ComboBox { comboBoxValue :: ByteString, comboBoxValues :: [ByteString] }
            -- ^ @type combo@: the current value and the allowed values (@var@)
            | SpinButton { spinButtonValue, spinButtonMinBound, spinButtonMaxBound :: Int }
            -- ^ @type spin@: the current value followed by the min and max bounds
            | OString ByteString
            -- ^ @type string@: free-form text
            | Button
            -- ^ @type button@: carries no value
            deriving (Eq, Show)

-- | String literals denote 'OString' options.
instance IsString Option where
  fromString str = OString (BS.pack str)

-- | Parser for a single line of engine output.
--
-- The given 'Position' is the one engine moves are interpreted against: it
-- resolves the moves of @info pv@, @info currmove@ and @bestmove@ (the
-- @ponder@ move of @bestmove@ is resolved after playing the best move).
-- Leading and trailing whitespace is skipped. Alternatives are tried in
-- order; attoparsec's '<|>' backtracks on failure, which is what lets
-- keywords sharing a prefix (e.g. @currmove@ / @currmovenumber@) coexist.
command :: Position -> Parser Command
command pos = skipSpace *> choice
  [ "id" `kv` name
  , "id" `kv` author
  , "option" `kv` opt
  , "uciok" $> UCIOk
  , "readyok" $> ReadyOK
  , "info" `kv` fmap Info (sepBy1 infoItem skipSpace)
  , "bestmove" `kv` ("(none)" $> BestMove Nothing <|> bestmove)
  ] <* skipSpace
 where
  -- "id name <rest of line>" / "id author <rest of line>"
  name = Name <$> kv "name" takeByteString
  author = Author <$> kv "author" takeByteString
  -- "option name <name> type <kind> ...": the option name is everything up
  -- to the " type" keyword, then one of the kind-specific parsers below.
  opt = do
    void "name"
    skipSpace
    optName <- BS.pack <$> manyTill anyChar (skipSpace *> "type")
    skipSpace
    optValue <- spin <|> check <|> combo <|> str <|> button
    pure $ Option optName optValue
  -- "check default true|false"
  check =
    fmap CheckBox $ "check" *> skipSpace *> "default" *> skipSpace *>
                    ("false" $> False <|> "true" $> True)
  -- "spin default <n> min <n> max <n>", expected in exactly this order
  spin = do
    void "spin"
    skipSpace
    value <- "default" *> skipSpace *> signed decimal <* skipSpace
    minValue <- "min" *> skipSpace *> signed decimal <* skipSpace
    maxValue <- "max" *> skipSpace *> signed decimal
    pure $ SpinButton value minValue maxValue
  -- "combo default <v> var <v1> var <v2> ...": each value runs up to the
  -- next "var" keyword; the last one runs to the end of the line.
  combo = do
    void "combo"
    skipSpace
    def <- fmap BS.pack $ "default" *> skipSpace *> manyTill anyChar var
    (vars, lastVar) <- (,) <$> many (manyTill anyChar var)
                           <*> takeByteString
    pure $ ComboBox def (map BS.pack vars <> [lastVar])
  -- separator between combo values
  var = skipSpace *> "var" *> skipSpace
  -- "string default <rest of line>"
  str = fmap OString $
    "string" *> skipSpace *> "default" *> skipSpace *> takeByteString
  button = "button" $> Button
  -- One "<keyword> <value>" item of an info line. "string" consumes the rest
  -- of the line, so it can only appear last.
  infoItem = Depth <$> kv "depth" decimal
         <|> SelDepth <$> kv "seldepth" decimal
         <|> MultiPV <$> kv "multipv" decimal
         <|> kv "score" score
         <|> Nodes <$> kv "nodes" decimal
         <|> NPS <$> kv "nps" decimal
         <|> HashFull <$> kv "hashfull" decimal
         <|> TBHits <$> kv "tbhits" decimal
         <|> Elapsed . ms . fromInteger <$> kv "time" decimal
         <|> kv "pv" pv
         <|> kv "currmove" currmove
         <|> CurrMoveNumber <$> kv "currmovenumber" decimal
         <|> String <$> kv "string" takeByteString
  -- "cp <n>" or "mate <n>", optionally followed by "upperbound"/"lowerbound"
  score = do
    s <- kv "cp" (CentiPawns <$> signed decimal)
     <|> kv "mate" (MateIn <$> signed decimal)
    b <- optional $ skipSpace *> (  UpperBound <$ "upperbound"
                                <|> LowerBound <$ "lowerbound"
                                 )
    pure $ Score s b
  -- A whitespace-separated list of moves, each resolved in the position
  -- reached by the preceding ones; fails on the first unresolvable move.
  pv = varToVec pos <$> sepBy mv skipSpace >>= \case
    Right v -> pure . PV $ v
    Left s  -> fail $ "Failed to parse move " <> s
  -- A single move resolved against the current position.
  currmove = fmap (fromUCI pos) mv >>= \case
    Just m  -> pure $ CurrMove m
    Nothing -> fail "Failed to parse move"

  -- Move text in coordinate notation: two squares (file a-h, rank 1-8) and
  -- an optional promotion letter (q, r, b, n). Yields the matched text only;
  -- legality is checked by the callers via fromUCI.
  mv = BS.unpack . fst <$> match (sq *> sq *> optional (satisfy p)) where
    sq = satisfy (inRange ('a', 'h')) *> satisfy (inRange ('1', '8'))
    p 'q' = True
    p 'r' = True
    p 'b' = True
    p 'n' = True
    p _   = False
  -- "<move> [ponder <move>]"; the ponder move is resolved in the position
  -- after the best move has been played.
  bestmove = do
    m <- mv
    ponder <- optional (skipSpace *> kv "ponder" mv)
    case fromUCI pos m of
      Just m' -> case ponder of
        Nothing -> pure . BestMove . Just $ (m', Nothing)
        Just p -> case fromUCI (doPly pos m') p of
          Just p' -> pure . BestMove . Just $ (m', Just p')
          Nothing -> fail $ "Failed to parse ponder move " <> p
      Nothing -> fail $ "Failed to parse best move " <> m
  -- keyword, optional whitespace, value
  kv k v = k *> skipSpace *> v

-- | Resolve a sequence of UCI move strings into plies, applying each one to
-- the position reached by the previous moves.
--
-- Returns the plies as an unboxed vector, or @Left s@ where @s@ is the first
-- move string that is not valid in the position it applies to.
varToVec :: Position -> [String] -> Either String (Unboxed.Vector Ply)
varToVec p xs = Unboxed.createT $ do
  mvec <- Unboxed.new (length xs)
  next <- newSTRef 0
  cur  <- newSTRef p
  let step uci = do
        here <- readSTRef cur
        case fromUCI here uci of
          Nothing  -> pure (Left uci)
          Just ply -> do
            slot <- readSTRef next
            Unboxed.write mvec slot ply
            modifySTRef next (+ 1)
            writeSTRef cur (unsafeDoPly here ply)
            pure (Right ())
  outcomes <- traverse step xs
  pure (mvec <$ sequenceA outcomes)

-- | Start a UCI engine from an executable name and its command line
-- arguments, allowing two seconds for the handshake and printing
-- unparseable engine output to standard output.
start :: String -> [String] -> IO (Maybe Engine)
start executable arguments = start' (sec 2) putStrLn executable arguments

-- | Start a UCI engine with the given timeout for initialisation.
--
-- The engine is sent @uci@ and must answer with @uciok@ within the given
-- time. If it takes longer, 'Nothing' is returned and the engine is shut
-- down via 'quit' (which terminates the process if it does not exit).
-- If the handshake throws (for example because the engine exits and
-- reading its output hits end of file), the process is terminated before
-- the exception is re-thrown, so it is not leaked.
--
-- Engine output that cannot be parsed is passed to the given output
-- function. On success, a background thread is started that reads the
-- engine's output for the lifetime of the 'Engine'.
start' :: KnownDivRat unit Microsecond => Time unit -> (String -> IO ()) -> String -> [String] -> IO (Maybe Engine)
start' tout outputStrLn cmd args = do
  (Just inH, Just outH, Nothing, procH) <- createProcess (proc cmd args) {
      std_in = CreatePipe, std_out = CreatePipe
    }
  hSetBuffering inH LineBuffering
  e <- Engine inH outH procH outputStrLn Nothing Nothing Nothing HashMap.empty <$>
       newEmptyMVar <*> newIORef False <*>
       newBroadcastTChanIO <*> newBroadcastTChanIO <*>
       newIORef (startpos, Seq.empty)
  -- Without this, an exception during the handshake would leave the child
  -- process running with nobody holding its handle.
  handshake <- (send e "uci" *> timeout tout (initialise e))
                 `onException` terminateProcess procH
  case handshake of
    Just e' -> do
      tid <- forkIO . infoReader $ e'
      pure . Just $ e' { infoThread = Just tid }
    Nothing -> quit e $> Nothing

-- | Consume engine output until @uciok@, collecting the engine's name,
-- author and options along the way.
--
-- Empty lines are skipped, lines that do not parse are echoed through the
-- engine's output function, and any other recognised command is ignored.
initialise :: Engine -> IO Engine
initialise engine@Engine{outH, outputStrLn, game} = do
  raw <- BS.hGetLine outH
  if BS.null raw
    then initialise engine
    else do
      (pos, _) <- readIORef game
      case parseOnly (command pos <* endOfInput) raw of
        Right UCIOk -> pure engine
        Right (Name n) -> initialise engine { name = Just n }
        Right (Author a) -> initialise engine { author = Just a }
        Right (Option key value) ->
          initialise engine { options = HashMap.insert key value (options engine) }
        Right _ -> initialise engine
        Left _ -> do
          outputStrLn (BS.unpack raw)
          initialise engine

-- | Read and dispatch engine output forever; 'start'' runs this in its own
-- thread once initialisation has succeeded.
--
-- Each line is parsed against the current game position. @readyok@
-- releases a waiting 'isready', @info@ lines are broadcast on the info
-- channel, and @bestmove@ clears the searching flag before its result is
-- broadcast. Unparseable lines are reported through the output function.
infoReader :: Engine -> IO ()
infoReader e@Engine{outH, outputStrLn, isReady, isSearching, infoChan, bestMoveChan} =
  forever $ do
    raw <- BS.hGetLine outH
    pos <- currentPosition e
    dispatch raw (parseOnly (command pos <* endOfInput) raw)
 where
  dispatch raw = \case
    Right ReadyOK -> putMVar isReady ()
    Right (Info items) -> atomically (writeTChan infoChan items)
    Right (BestMove bm) -> do
      writeIORef isSearching False
      atomically (writeTChan bestMoveChan bm)
    Right _ -> pure ()
    Left err -> outputStrLn (err <> ":" <> show raw)

-- | Wait until the engine is ready to take more commands.
--
-- Sends @isready@ and blocks until the reader thread has seen @readyok@.
isready :: Engine -> IO ()
isready e = send e "isready" >> takeMVar (isReady e)

-- | Write one command line (a newline is appended) to the engine.
--
-- Afterwards the process is checked; if it has already exited, its
-- 'ExitCode' is thrown as an exception.
send :: Engine -> Builder -> IO ()
send Engine{inH, procH} b = do
  hPutBuilder inH $ b <> "\n"
  exited <- getProcessExitCode procH
  maybe (pure ()) throwIO exited

-- | Parameters for 'search', each rendered as an argument of the @go@ command.
data SearchParam = SearchMoves [Ply]
                -- ^ restrict search to the specified moves only (@searchmoves@)
                 | Ponder
                -- ^ start searching in pondering mode (@ponder@)
                 | TimeLeft Color (Time Millisecond)
                -- ^ time left on the given side's clock (@wtime@ / @btime@, in milliseconds)
                 | TimeIncrement Color (Time Millisecond)
                -- ^ time increment per move for the given side (@winc@ / @binc@, in milliseconds)
                 | MovesToGo Natural
                -- ^ number of moves to the next time control (@movestogo@)
                 | MoveTime (Time Millisecond)
                -- ^ fixed search time for this move (@movetime@, in milliseconds)
                 | MaxNodes Natural
                -- ^ limit on the number of nodes to search (@nodes@)
                 | MaxDepth Natural
                -- ^ limit on the search depth (@depth@)
                 | Infinite
                -- ^ search until 'stop' gets called (@infinite@)
                 deriving (Eq, Show)

-- | Restrict the search to the given plies only.
searchmoves :: [Ply] -> SearchParam
searchmoves plies = SearchMoves plies

-- | Search in pondering mode.
ponder :: SearchParam
ponder = Ponder

-- | Time left on the clock of the given side, in any time unit.
timeleft :: KnownDivRat unit Millisecond => Color -> Time unit -> SearchParam
timeleft side remaining = TimeLeft side (toUnit remaining)

-- | Time increment per move for the given side, in any time unit.
timeincrement :: KnownDivRat unit Millisecond => Color -> Time unit -> SearchParam
timeincrement side inc = TimeIncrement side (toUnit inc)

-- | Number of moves until the next time control.
movestogo :: Natural -> SearchParam
movestogo n = MovesToGo n

-- | Search for exactly the given amount of time, in any time unit.
movetime :: KnownDivRat unit Millisecond => Time unit -> SearchParam
movetime t = MoveTime (toUnit t)

-- | Limit the search to the given number of nodes.
nodes :: Natural -> SearchParam
nodes n = MaxNodes n

-- | Limit the search to the given depth.
depth :: Natural -> SearchParam
depth n = MaxDepth n

-- | Search until 'stop' is called.
infinite :: SearchParam
infinite = Infinite

-- | Whether a search started with 'search' is still running, i.e. no
-- @bestmove@ has been received since it began.
searching :: MonadIO m => Engine -> m Bool
searching = liftIO . readIORef . isSearching

-- | Instruct the engine to begin searching.
--
-- Returns duplicates of the engine's best-move and info broadcast channels.
-- They are created before @go@ is sent, so no output of this search is
-- missed. The searching flag (see 'searching') is raised before @go@ is sent
-- and lowered again if sending fails.
search :: MonadIO m
       => Engine -> [SearchParam]
       -> m (TChan BestMove, TChan [Info])
search e@Engine{isSearching} params = liftIO $ do
  chans <- atomically $ (,) <$> dupTChan (bestMoveChan e)
                            <*> dupTChan (infoChan e)
  -- Raise the flag *before* sending "go": the reader thread clears it on
  -- "bestmove", and a fast engine can answer before this thread runs again.
  -- Raising it afterwards could overwrite that reset and leave it stuck.
  writeIORef isSearching True
  send e (fold . intersperse " " $ "go" : foldr build mempty params)
    `onException` writeIORef isSearching False
  pure chans
 where
  build (SearchMoves plies) xs = "searchmoves" : (fromString . toUCI <$> plies) <> xs
  build Ponder xs = "ponder" : xs
  build (TimeLeft White (floor . unTime -> x)) xs = "wtime" : integerDec x : xs
  build (TimeLeft Black (floor . unTime -> x)) xs = "btime" : integerDec x : xs
  build (TimeIncrement White (floor . unTime -> x)) xs = "winc" : integerDec x : xs
  build (TimeIncrement Black (floor . unTime -> x)) xs = "binc" : integerDec x : xs
  build (MovesToGo x) xs = "movestogo" : naturalDec x : xs
  build (MoveTime (floor . unTime -> x)) xs = "movetime" : integerDec x : xs
  build (MaxNodes x) xs = "nodes" : naturalDec x : xs
  build (MaxDepth x) xs = "depth" : naturalDec x : xs
  build Infinite xs = "infinite" : xs
  naturalDec = integerDec . toInteger

-- | Tell the engine that the pondered move was played, switching a ponder
-- search into a normal search.
ponderhit :: MonadIO m => Engine -> m ()
ponderhit engine = liftIO (send engine "ponderhit")

-- | Ask the engine to stop the search in progress.
stop :: MonadIO m => Engine -> m ()
stop engine = liftIO (send engine "stop")

-- | Look up an engine option by name.
getOption :: ByteString -> Engine -> Maybe Option
getOption optName engine = HashMap.lookup optName (options engine)

-- | Set a spin option to a particular value.
--
-- The option must exist, must be a 'SpinButton', and the value must lie
-- within its advertised min/max bounds; otherwise this calls 'error'.
-- On success, @setoption@ is sent to the engine and the returned 'Engine'
-- records the new value.
setOptionSpinButton :: MonadIO m => ByteString -> Int -> Engine -> m Engine
setOptionSpinButton n v c
  | Just (SpinButton _ minValue maxValue) <- getOption n c
  , inRange (minValue, maxValue) v
  = liftIO $ do
    send c $ "setoption name " <> byteString n <> " value " <> intDec v
    pure $ c { options = HashMap.update (set v) n $ options c }
  | otherwise
  = error "No option with that name or value out of range"
 where
  -- The guard already established the option is a SpinButton; the
  -- catch-all just keeps this helper total (no incomplete patterns).
  set val = \case
    opt@SpinButton{} -> Just opt { spinButtonValue = val }
    opt              -> Just opt

-- | Set an option to a string value.
--
-- Sends @setoption name <n> value <v>@ to the engine and records the value
-- in the returned 'Engine'. A 'ComboBox' option keeps its list of allowed
-- values and only has its current value replaced; any other existing option
-- is replaced by an 'OString'. Option names not present in 'options' are
-- still sent to the engine but are not added to the map.
setOptionString :: MonadIO m => ByteString -> ByteString -> Engine -> m Engine
setOptionString n v e = liftIO $ do
  send e $ "setoption name " <> byteString n <> " value " <> byteString v
  pure $ e { options = HashMap.update (set v) n $ options e }
 where
  -- Preserve combo metadata; previously it was discarded by turning the
  -- option into an OString.
  set val = \case
    combo@ComboBox{} -> Just combo { comboBoxValue = val }
    _                -> Just $ OString val

-- | Return the final position of the currently active game: the starting
-- position with every recorded ply applied in order.
currentPosition :: MonadIO m => Engine -> m Position
currentPosition Engine{game} =
  liftIO $ (\(initial, plies) -> foldl' doPly initial plies) <$> readIORef game

-- | Append a 'Ply' to the game history and send the new position to the engine.
--
-- The ply must be legal in the current position; otherwise an 'IllegalMove'
-- 'UCIException' is thrown and the history is left unchanged.
addPly :: MonadIO m => Engine -> Ply -> m ()
addPly e@Engine{game} m = liftIO $ do
  pos <- currentPosition e
  if m `elem` legalPlies pos
    then do
      atomicModifyIORef' game $ \(p, h) -> ((p, h |> m), ())
      sendPosition e
    else throwIO $ IllegalMove m

-- | Replace the most recent ply of the game history with the given ply,
-- then send the new position to the engine.
--
-- If the history is empty, the ply is simply appended. The ply is checked
-- for legality in the position before the replaced ply; if it is illegal,
-- an 'IllegalMove' 'UCIException' is thrown and the history is unchanged.
replacePly :: MonadIO m => Engine -> Ply -> m ()
replacePly e@Engine{game} pl = liftIO $ do
  (initial, history) <- readIORef game
  -- Validate before touching the IORef, so a rejected ply does not lose the
  -- move it was supposed to replace.
  if pl `notElem` legalPlies (foldl' doPly initial (dropLast history))
    then throwIO $ IllegalMove pl
    else do
      atomicModifyIORef' game $ \g -> (fmap ((|> pl) . dropLast) g, ())
      sendPosition e
 where
  -- Total "drop the last ply": an empty history stays empty instead of
  -- storing a failing incomplete-pattern thunk in the game IORef.
  dropLast xs = case Seq.viewr xs of
    xs' :> _   -> xs'
    Seq.EmptyR -> xs

-- | Send the current game to the engine as
-- @position fen <starting FEN> [moves <ply> ...]@; the @moves@ part is
-- omitted when no plies have been played.
sendPosition :: Engine -> IO ()
sendPosition e@Engine{game} = do
  (initial, plies) <- readIORef game
  let moves = case toList plies of
        []  -> []
        pls -> "moves" : map (fromString . toUCI) pls
  send e . fold . intersperse " " $
    "position" : "fen" : fromString (toFEN initial) : moves

-- | Quit the engine, allowing it one second to exit before it is terminated.
quit :: MonadIO m => Engine -> m (Maybe ExitCode)
quit engine = quit' (sec 1) engine

-- | Quit the engine, waiting at most the given time for the process to exit.
--
-- Kills the reader thread (if any) and sends @quit@. Returns the exit code
-- if the process exits in time, or if it had already exited and sending
-- failed with its 'ExitCode'. Otherwise the process is terminated and
-- 'Nothing' is returned.
quit' :: (KnownDivRat unit Microsecond, MonadIO m)
      => Time unit -> Engine -> m (Maybe ExitCode)
quit' t e@Engine{procH, infoThread} = liftIO . handle (pure . Just) $ do
  mapM_ killThread infoThread
  send e "quit"
  exited <- timeout t (waitForProcess procH)
  case exited of
    Nothing -> terminateProcess procH $> Nothing
    Just ec -> pure (Just ec)
