module Search.Perft (perft, divide, allMoves) where

import           AppPrelude

import           Models.Move
import           Models.Position
import           MoveGen.MakeMove
import           MoveGen.PieceCaptures
import           MoveGen.PieceQuietMoves

import qualified Data.Map                as Map
import           Models.Score


-- | Count the leaf nodes of the move tree rooted at the given position,
-- searched to the given depth.
--
-- A depth of 0 yields 1 (the position itself is the only leaf). At depth 1
-- the moves returned by 'allMoves' are counted directly instead of being
-- played out with 'makeMove'.
perft :: Depth -> Position -> Int
perft = go 0
  where
    -- 'nodes' is the running leaf count threaded through the whole traversal,
    -- so sibling subtrees accumulate into one strict counter.
    go !nodes !depth !pos
      -- Nothing left to search. Without this case the countdown below would
      -- step past 1 and never reach a base case (or wrap around, should
      -- Depth be a bounded type).
      | depth == 0 = nodes + 1
      -- Last ply: every generated move is a leaf, so just count them.
      | depth == 1 = nodes + length moves
      | otherwise  = foldl' descend nodes moves
      where
        -- Play one move and add the leaves found beneath it to the count.
        descend !acc mv = go acc (depth - 1) (makeMove mv pos)
        moves = allMoves pos


-- | Break a perft count down by root move: every move returned by
-- 'allMoves' for the position is mapped to the number of leaf nodes found
-- beneath it when the search as a whole runs to the given depth.
--
-- A depth of 0 yields an empty map, because no root move is explored.
divide :: Depth -> Position -> Map Move Int
divide !depth !pos
  -- Guard the degenerate case explicitly; otherwise 'perft' would be asked
  -- to search to @depth - 1@, which lies below its base case.
  | depth == 0 = Map.empty
  -- Each root move is itself the single leaf, so no recursion is needed.
  | depth == 1 = Map.fromList [(mv, 1) | mv <- moves]
  | otherwise  = Map.fromList (subtreeCount <$> moves)
  where
    -- Pair a root move with the leaf count of the position it leads to.
    subtreeCount !mv = (mv, perft (depth - 1) (makeMove mv pos))
    moves = allMoves pos


-- | All moves generated for a position: the result of 'allCaptures'
-- followed by the result of 'allQuietMoves'.
allMoves :: Position -> [Move]
allMoves pos = allCaptures pos <> allQuietMoves pos