{-# LANGUAGE BangPatterns #-}
module Data.Poker.Evaluate
    ( handValue
    , handValue_n
    , numericalHandValue
    , numericalHandValue_n
    , distHandValue
    ) where

import Data.Poker.Definitions
import Data.Poker.Interface
import Data.Poker.Deck
import Data.Poker.Brute
import Data.Poker.Enumerate

import Foreign.C

import Data.Array.MArray
import Data.Array.ST as ST
import Data.Array.Unboxed
import Data.Array.Base



-- Raw binding to the C wrapper @hs_StdDeck_StdRules_EVAL_N@. It is given a
-- card mask and the number of cards in that mask, and its result is wrapped
-- as a 'NumericalHandValue' by 'numericalHandValue_n'. The Haskell type has
-- no 'IO', so the C side is assumed to be side-effect free; @unsafe@ means
-- the call must never re-enter Haskell.
foreign import ccall unsafe "hs_StdDeck_StdRules_EVAL_N" c_eval_n :: StdDeck_CardMask -> CInt -> CInt
-- | Evaluate the strongest hand that can be formed from the given cards and
--   return it in numerical form.
--
--   The card count is measured with 'size' and handed to
--   'numericalHandValue_n'. If the count is constant at the call site, call
--   'numericalHandValue_n' directly instead, because it is significantly
--   faster.
--
--   A 'NumericalHandValue' is significantly cheaper to compute than a
--   'HandValue', so prefer this function over 'handValue' when possible.
numericalHandValue :: CardSet            -- ^ Available cards.
                   -> NumericalHandValue
numericalHandValue cards = numericalHandValue_n cardCount cards
  where
    cardCount = size cards

-- | Evaluate the strongest hand that can be formed from the given cards,
--   with the caller supplying the number of cards in the set. This is
--   significantly faster than 'numericalHandValue' if the size of the card
--   set is constant.
--
--   The count is passed unchanged to the C evaluator. It is expected to match
--   the actual number of cards in the set; no check is performed here.
--
--   A 'NumericalHandValue' is significantly cheaper to compute than a
--   'HandValue', so prefer this function over 'handValue_n' when possible.
numericalHandValue_n :: Int               -- ^ Size of card set.
                     -> CardSet           -- ^ Available cards.
                     -> NumericalHandValue
numericalHandValue_n cardCount (CardSet rawMask) =
    NumericalHandValue (fromIntegral evaluated)
  where
    -- Raw 'CInt' result of the C evaluator.
    evaluated = c_eval_n rawMask (fromIntegral cardCount)

-- | Evaluate the strongest hand that can be formed from the given cards and
--   return it as a 'HandValue'.
--
--   The hand is first evaluated with 'numericalHandValue' and then converted
--   with 'numericalToHandValue'. If the numerical form is enough, call
--   'numericalHandValue' directly and skip the conversion.
handValue :: CardSet    -- ^ Available cards.
          -> HandValue
handValue cards = numericalToHandValue (numericalHandValue cards)

-- | Evaluate the strongest hand that can be formed from the given cards,
--   with the caller supplying the number of cards in the set. This is
--   significantly faster than 'handValue' if the size of the card set is
--   constant.
--
--   The hand is evaluated with 'numericalHandValue_n' and then converted
--   with 'numericalToHandValue'.
handValue_n :: Int         -- ^ Size of card set.
            -> CardSet     -- ^ Available cards.
            -> HandValue
handValue_n cardCount cards =
  numericalToHandValue (numericalHandValue_n cardCount cards)










-- Profiling notes (measurement setup not recorded):
--
--   Full:            70ms
--   plain eval:      45ms, 500 samples
--   loop:            10ms
--
--   brute cost:  70ms - 45ms  = 25ms
--   eval cost:   45ms - 10ms  = 35ms

-- | Count how often each hand value occurs across the communities that
--   'enumerateFiveCards' generates for the given cards.
--
--   Each community is passed to @step@, combined with @mask@ using 'union',
--   and evaluated with @'numericalHandValue_n' 7@. The result is converted to
--   a 'ConsecutiveHandValue'. The returned array maps each
--   'ConsecutiveHandValue' (from 'minBound' to 'maxBound') to the number of
--   communities that produced it.
--
--   NOTE(review): the hard-coded 7 presumes @mask@ holds exactly two cards
--   and each community holds five cards disjoint from it. TODO: confirm with
--   callers and with 'enumerateFiveCards'.
distHandValue :: CardSet -> UArray ConsecutiveHandValue Int
distHandValue !mask = runSTUArray (do
  -- One counter per ConsecutiveHandValue, all initialised to zero.
  arr <- newArray (minBound, maxBound) 0
  let step community = do
        let !fastValue = numericalHandValue_n 7 (community `union` mask)
        let !(ConsecutiveHandValue idx) = numericalToConsecutive fastValue
        -- unsafeRead/unsafeWrite index by a raw 0-based offset and skip the
        -- bounds check. NOTE(review): this is only correct if the Int inside
        -- ConsecutiveHandValue already equals the offset from minBound
        -- (i.e. minBound wraps 0). Confirm in Data.Poker.Definitions.
        n <- unsafeRead arr idx
        unsafeWrite arr idx (n+1)
  enumerateFiveCards mask step
  return arr)