module Search.MoveOrdering (getSortedMoves) where

import           AppPrelude

import           Evaluation.Evaluation
import           Models.Command           (EngineOptions)
import           Models.Move
import           Models.Piece
import           Models.Position
import           Models.Score
import           MoveGen.MakeMove         (makeMove)
import           MoveGen.MoveQueries
import           MoveGen.PieceCaptures    (allCaptures)
import           MoveGen.PieceQuietMoves  (allQuietMoves)
import           MoveGen.PositionQueries
import qualified Utils.KillersTable       as KillersTable
import           Utils.KillersTable       (KillersTable)
import qualified Utils.TranspositionTable as TTable
import           Utils.TranspositionTable (TTable)



-- Move Ordering:
-- - Transposition table (PV / Refutation) move
-- - Winning captures (SEE >= 0) (Ordered by SEE)
-- - 2 Killer moves
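-- - Quiet moves (Ordered by Static Eval, only when depth > 2)
-- - Losing captures (SEE < 0) (Ordered by SEE)

-- Reduced moves:
-- Late killers / quiet moves. getSortedMoves returns the quiet moves and the
-- losing captures as a separate second list when depth >= 3 and the side to
-- move is not in check.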

-- | Build the ordered move list for a node, split into two groups.
--
-- Front group, in order:
--
--   * the transposition-table move, if the table has one for this position;
--   * the winning captures (from 'getSortedCaptures'), minus the TT move;
--   * the killer moves for @ply@, minus the TT move.
--
-- Back group, in order:
--
--   * the quiet moves (from 'getSortedQuietMoves'), minus the TT and killer moves;
--   * the losing captures, minus the TT move.
--
-- When @depth >= 3@ and the side to move is not in check, the two groups are
-- returned separately. Otherwise they are concatenated into the first
-- component and the second component is empty.
getSortedMoves
  :: (?killersTable :: KillersTable, ?tTable :: TTable, ?opts :: EngineOptions)
  => Depth -> Ply -> Position -> IO ([Move], [Move])
getSortedMoves !depth !ply pos = do
  -- A 'Maybe Move' flattened to a zero- or one-element list.
  hashMoves <- toList <$> TTable.lookupBestMove (getZobristKey pos)
  killers   <- getKillers ply pos
  let
    withoutHash :: [Move] -> [Move]
    withoutHash = filter (`notElem` hashMoves)

    withoutHashOrKillers :: [Move] -> [Move]
    withoutHashOrKillers = filter (`notElem` (hashMoves <> killers))

    front :: [Move]
    front = hashMoves <> withoutHash winningCaptures <> withoutHash killers

    back :: [Move]
    back = withoutHashOrKillers quietMoves <> withoutHash losingCaptures

    keepGroupsApart :: Bool
    keepGroupsApart = depth >= 3 && not (isKingInCheck pos)

  pure $ if keepGroupsApart
    then (front, back)
    else (front <> back, [])
  where
    (winningCaptures, losingCaptures) = getSortedCaptures pos
    quietMoves                        = getSortedQuietMoves depth pos


-- | Killer moves recorded for @ply@ in the implicit killers table, keeping
-- only those for which 'isLegalQuietMove' holds in @pos@.
getKillers
  :: (?killersTable :: KillersTable) => Ply -> Position -> IO [Move]
getKillers !ply pos = do
  stored <- KillersTable.lookupMoves ply
  pure $ filter (\mv -> isLegalQuietMove mv pos) stored


-- | Captures available in @pos@, split by static exchange evaluation.
--
-- Captures whose promotion is not in 'bestPromotions' are dropped. Each
-- remaining capture is scored once with 'evaluateExchange'. Captures
-- scoring @>= 0@ go into the first list and the rest into the second. Each
-- list is ordered by descending score.
getSortedCaptures :: Position -> ([Move], [Move])
getSortedCaptures pos = (bestFirst winning, bestFirst losing)
  where
    scored :: [(Move, Score)]
    scored =
      [ (mv, evaluateExchange mv pos)
      | mv <- allCaptures pos
      , mv.promotion `member` bestPromotions
      ]

    (winning, losing) = partition (\(_, see) -> see >= 0) scored

    bestFirst :: [(Move, Score)] -> [Move]
    bestFirst = map fst . sortOn (Down . snd)


-- | Quiet (non-capturing) moves generated by 'allQuietMoves' for @pos@.
--
-- At @depth <= 2@ the moves are returned in generation order. Deeper, they
-- are ordered best-first by the negated 'evaluatePosition' of the position
-- after the move. The negation presumably converts the child's
-- side-to-move score to the mover's view (confirm against
-- 'evaluatePosition').
--
-- Each move is scored exactly once before sorting. The 'sortOn' in scope is
-- mono-traversable's 'SemiSequence' version, which is @sortBy . comparing@.
-- Passing the evaluation as its key would re-run 'makeMove' and
-- 'evaluatePosition' on every comparison. The sort is stable, so the result
-- is the same as sorting on the key directly.
getSortedQuietMoves :: Depth -> Position -> [Move]
getSortedQuietMoves !depth pos
  | depth <= 2 = quietMoves
  | otherwise  = map fst . sortOn (Down . snd) $ map withScore quietMoves
  where
    quietMoves :: [Move]
    quietMoves = allQuietMoves pos

    -- The score thunk lives inside the tuple and is shared by every
    -- comparison, so each move is evaluated at most once.
    withScore :: Move -> (Move, Score)
    withScore mv = (mv, negate (evaluatePosition (makeMove mv pos)))