module Data.Poker.Deck
    (
    -- * HandValue
        HandValue(..),
        NumericalHandValue(..),
        ConsecutiveHandValue(..),
        isNoPair,
        isOnePair,
        isTwoPair,
        isThreeOfAKind,
        isStraight,
        isFlush,
        isFullHouse,
        isFourOfAKind,
        isStraightFlush,
    -- * Cards
        Card(..),
        Rank(..),
        Suit(..),
        Kicker,
        mkCard,
        cardRank,
        cardSuit,
        rankIdentifiers,
        suitIdentifiers,
    -- * CardSets
        CardSet(..),
        toList,
        fromList,
        singleton,
        size,
        countRank,
        countSuit,
        member,
        empty,
        isEmpty,
        union,
        intersection,
        inverse
    ) where


import Data.Poker.Definitions


import System.Random
import Foreign.C
import qualified Data.Vector.Unboxed as VU
import qualified Data.Vector as V
import Data.List                                ( foldl' )
import Data.Int
import Data.Word
import Data.Bits
import Data.Ix
import Data.Char

-- | The thirteen card ranks. Constructor order drives the derived 'Ord' and
--   'Enum' instances, so 'Two' is the lowest rank and 'Ace' the highest.
data Rank
    = Two | Three | Four | Five | Six | Seven | Eight
    | Nine | Ten | Jack | Queen | King | Ace
    deriving (Show, Eq, Ord, Enum, Bounded)

-- | The four suits, in the order used by the derived 'Enum' instance.
data Suit
    = Hearts | Diamonds | Clubs | Spades
    deriving (Show, Eq, Ord, Enum, Bounded)

-- | A 'Rank' appearing as a kicker field of a 'HandValue'.
type Kicker = Rank

-- | This structure represents the value of a poker hand as a high-level ADT.
--
--   The following must be true for a HandValue to be valid:
--
--     * All kickers must be in descending order.
--
--     * No 'Rank' may occur twice.
--
--     * The ranks and kickers together may not form a higher-value hand.
--
--       For example, @'NoPair' 'Six' 'Five' 'Four' 'Three' 'Two'@
--       is not a valid HandValue (it is a straight).
--
--   The derived 'Ord' instance compares constructors in declaration order
--   first and then fields left to right. For example, any 'Flush' beats any
--   'Straight', and two hands of the same category are decided by their
--   ranks and kickers.
data HandValue =
  NoPair        Kicker Kicker Kicker Kicker Kicker |
  OnePair       Rank Kicker Kicker Kicker |
  TwoPair       Rank Rank Kicker |
  ThreeOfAKind  Rank Kicker Kicker |
  Straight      Rank |
  Flush         Kicker Kicker Kicker Kicker Kicker |
  FullHouse     Rank Rank |
  FourOfAKind   Rank Kicker |
  StraightFlush Rank
  deriving (Show,Eq,Ord)

-- | Abstract representation of a card consisting of a 'Rank' and a 'Suit'.
--
--   Internally this is a card index: 'mkCard' stores
--   @fromEnum rank + 13 * fromEnum suit@, and the 'Bounded' instance spans
--   0..51. The constructor is exported, so out-of-range values can be built
--   directly. Prefer 'mkCard' (always in range) or 'toEnum' (range-checked).
newtype Card = Card CInt deriving (Eq,Ord)

-- | A set of cards, stored as a raw card bit mask.
--
--   The derived 'Eq' and 'Ord' compare the underlying mask bits directly, so
--   two sets are equal only if their masks match bit for bit.
--
--   NOTE(review): the derived 'Bounded' uses the mask type's own bounds, which
--   presumably include bits that do not correspond to any card. Confirm
--   before relying on 'maxBound' as "all cards".
newtype CardSet = CardSet { unmask :: StdDeck_CardMask } deriving (Eq,Ord,Bounded)

-- | Isomorphic to 'HandValue' but computed much more efficiently.
--
--   If possible, this is the structure to use. The derived 'Eq' and 'Ord'
--   compare the wrapped 'Word' directly.
--   NOTE(review): presumably the encoding is chosen so that this ordering
--   agrees with 'HandValue''s; confirm against the conversion functions.
newtype NumericalHandValue = NumericalHandValue { unNumericalHandValue :: Word }
  deriving (Eq, Ord)

-- | Isomorphic to 'HandValue' but stored more efficiently.
--
--   This structure is bounded (0 through 7461, see its 'Bounded' instance) and
--   has an 'Ix' instance, which makes it especially useful as an Array index.
newtype ConsecutiveHandValue = ConsecutiveHandValue { unConsecutiveHandValue :: Int }
  deriving (Eq, Ord)

-- | Valid consecutive hand values are exactly the indices 0 through 7461.
instance Bounded ConsecutiveHandValue where
  maxBound = ConsecutiveHandValue 7461
  minBound = ConsecutiveHandValue 0

-- | Every 'Ix' operation is delegated to the wrapped 'Int'.
instance Ix ConsecutiveHandValue where
  range (lo, hi) =
    map ConsecutiveHandValue (range (unConsecutiveHandValue lo, unConsecutiveHandValue hi))
  index (lo, hi) v =
    index (unConsecutiveHandValue lo, unConsecutiveHandValue hi) (unConsecutiveHandValue v)
  inRange (lo, hi) v =
    inRange (unConsecutiveHandValue lo, unConsecutiveHandValue hi) (unConsecutiveHandValue v)
  rangeSize (lo, hi) =
    rangeSize (unConsecutiveHandValue lo, unConsecutiveHandValue hi)


-- | 'True' exactly when the hand was built with the 'NoPair' constructor.
isNoPair :: HandValue -> Bool
isNoPair hand = case hand of
  NoPair{} -> True
  _        -> False

-- | 'True' exactly when the hand was built with the 'OnePair' constructor.
isOnePair :: HandValue -> Bool
isOnePair hand = case hand of
  OnePair{} -> True
  _         -> False

-- | 'True' exactly when the hand was built with the 'TwoPair' constructor.
isTwoPair :: HandValue -> Bool
isTwoPair hand = case hand of
  TwoPair{} -> True
  _         -> False

-- | 'True' exactly when the hand was built with the 'ThreeOfAKind' constructor.
isThreeOfAKind :: HandValue -> Bool
isThreeOfAKind hand = case hand of
  ThreeOfAKind{} -> True
  _              -> False

-- | 'True' exactly when the hand was built with the 'Straight' constructor
--   (a 'StraightFlush' does not count).
isStraight :: HandValue -> Bool
isStraight hand = case hand of
  Straight{} -> True
  _          -> False

-- | 'True' exactly when the hand was built with the 'Flush' constructor
--   (a 'StraightFlush' does not count).
isFlush :: HandValue -> Bool
isFlush hand = case hand of
  Flush{} -> True
  _       -> False

-- | 'True' exactly when the hand was built with the 'FullHouse' constructor.
isFullHouse :: HandValue -> Bool
isFullHouse hand = case hand of
  FullHouse{} -> True
  _           -> False

-- | 'True' exactly when the hand was built with the 'FourOfAKind' constructor.
isFourOfAKind :: HandValue -> Bool
isFourOfAKind hand = case hand of
  FourOfAKind{} -> True
  _             -> False

-- | 'True' exactly when the hand was built with the 'StraightFlush' constructor.
isStraightFlush :: HandValue -> Bool
isStraightFlush hand = case hand of
  StraightFlush{} -> True
  _               -> False





-- | The weakest valid hand is the lowest no-pair hand that is not a straight
--   (seven-high); the strongest is an ace-high straight flush.
instance Bounded HandValue where
    maxBound = StraightFlush Ace
    minBound = NoPair Seven Five Four Three Two

-- | A card is shown as its two-character identifier, rank then suit (e.g. @Ah@).
instance Show Card where
    showsPrec _ card = showString (cardToString card)

-- | Reads a two-character card identifier (as accepted by 'stringToCard_'),
--   skipping leading whitespace. Everything after those two characters is
--   returned as the unconsumed remainder; the precedence is ignored.
instance Read Card where
    readsPrec _ input =
        case dropWhile isSpace input of
            (rankChar:suitChar:rest) ->
                maybe [] (\card -> [(card, rest)]) (stringToCard_ [rankChar, suitChar])
            _ -> []

-- | The lowest card index is the two of hearts (0), the highest the ace of
--   spades (51).
instance Bounded Card where
    minBound = mkCard Two Hearts
    maxBound = mkCard Ace Spades

-- | Cards enumerate by index, 0 ('minBound') through 51 ('maxBound').
--
--   'succ', 'pred' and 'toEnum' raise an error rather than leave that range.
--   Open-ended enumerations stop at a bound, following the Haskell report's
--   rule for types that are both 'Bounded' and 'Enum': an ascending
--   @[a, b ..]@ runs up to 'maxBound', a descending one down to 'minBound'.
instance Enum Card where
    succ card | card >= maxBound = error "Data.Poker.Deck.Card.succ: bad argument"
    succ (Card i) = Card (succ i)
    pred card | card <= minBound = error "Data.Poker.Deck.Card.pred: bad argument"
    pred (Card i) = Card (pred i)
    toEnum i | i < 0 || i > 51   = error "Data.Poker.Deck.Card.toEnum: bad argument"
    toEnum i = Card (fromIntegral i)
    fromEnum (Card i) = fromIntegral i
    enumFrom val = enumFromTo val maxBound
    -- The limit must lie in the direction of travel. Always using 'maxBound'
    -- made descending enumerations such as @[Card 10, Card 5 ..]@ empty.
    enumFromThen val step = enumFromThenTo val step bound
      where
        bound | step >= val = maxBound
              | otherwise   = minBound

-- | Picks a card index uniformly from the requested range. 'random' draws
--   from the full deck ('minBound' .. 'maxBound', i.e. indices 0..51).
instance Random Card where
    randomR (Card lo, Card hi) gen = (Card (fromIntegral idx), gen')
      where
        (idx, gen') = randomR (fromIntegral lo, fromIntegral hi :: Int) gen
    random = randomR (minBound, maxBound)

-- | A set is shown as the list of its cards, in 'toList' order.
instance Show CardSet where
    show cards = show (toList cards)

-- | A set is read as a list of cards, e.g. @[Ah,Kd]@, then folded with 'fromList'.
instance Read CardSet where
    readsPrec prec input =
        [ (fromList cards, rest) | (cards, rest) <- readsPrec prec input ]

-- | Random card sets: draw a raw mask, then clear every bit outside
--   0x1FFF1FFF1FFF1FFF (= 2305596714850918399), the bits kept for cards.
--
--   NOTE(review): 'randomR' masks after sampling, so its result can compare
--   below @lo@ under 'CardSet''s bitwise 'Ord'. Confirm callers don't rely
--   on the range being honoured.
instance Random CardSet where
    randomR (CardSet lo, CardSet hi) gen =
        (CardSet (fromIntegral (raw :: Int) .&. allCardBits), gen')
      where
        (raw, gen') = randomR (fromIntegral lo, fromIntegral hi) gen
        allCardBits = 0x1FFF1FFF1FFF1FFF
    random gen =
        (CardSet (raw .&. 0x1FFF1FFF1FFF1FFF), gen')
      where
        (raw, gen') = randomR (0, maxBound) gen

-- | Draws the wrapped 'Int' uniformly; 'random' covers the 'Bounded' range.
instance Random ConsecutiveHandValue where
    randomR (ConsecutiveHandValue lo, ConsecutiveHandValue hi) gen =
        (ConsecutiveHandValue n, gen')
      where
        (n, gen') = randomR (lo, hi) gen
    random = randomR (minBound, maxBound)

-- | Picks a rank uniformly via its 'Enum' position.
instance Random Rank where
    randomR (lo, hi) gen = (toEnum idx, gen')
      where
        (idx, gen') = randomR (fromEnum lo, fromEnum hi) gen
    random = randomR (minBound, maxBound)

-- | Picks a suit uniformly via its 'Enum' position.
instance Random Suit where
    randomR (lo, hi) gen = (toEnum idx, gen')
      where
        (idx, gen') = randomR (fromEnum lo, fromEnum hi) gen
    random = randomR (minBound, maxBound)


-- | Construct a card with the given rank and suit.
--
--   Cards are numbered suit-major: the index is
--   @13 * fromEnum suit + fromEnum rank@, giving 0..51.
mkCard :: Rank -> Suit -> Card
mkCard rank suit = Card (fromIntegral cardIndex)
  where
    cardIndex = 13 * fromEnum suit + fromEnum rank

-- | Inspect the rank of a card: its index modulo 13.
cardRank :: Card -> Rank
cardRank (Card idx) = toEnum rankIndex
  where
    rankIndex = fromIntegral idx `mod` 13

-- | Inspect the suit of a card: its index divided by 13.
cardSuit :: Card -> Suit
cardSuit (Card idx) = toEnum suitIndex
  where
    suitIndex = fromIntegral idx `div` 13

-- | Parse a two-character card identifier such as @"Ah"@ or @"td"@.
--
--   Both the rank and the suit character are matched case-insensitively, so
--   @"AH"@, @"ah"@ and @"Ah"@ all denote the ace of hearts. Input that is not
--   exactly two recognised characters yields 'Nothing'.
stringToCard_ :: String -> Maybe Card
stringToCard_ [rankChar,suitChar] = do
    rank <- lookup (toUpper rankChar) (map swap rankIdentifiers)
    -- Suit identifiers are lower-case; fold case just as the rank is folded
    -- above, so upper-case suits are accepted too.
    suit <- lookup (toLower suitChar) (map swap suitIdentifiers)
    return $ mkCard rank suit
  where
    swap (a,b) = (b,a)
stringToCard_ _ = Nothing

-- | Render a card as its two-character identifier, rank first (e.g. @"Ah"@).
cardToString :: Card -> String
cardToString card =
    [ glyph rankIdentifiers (cardRank card)
    , glyph suitIdentifiers (cardSuit card)
    ]
  where
    -- Every rank and suit has an entry in its table, so the match cannot fail.
    glyph :: Eq k => [(k, Char)] -> k -> Char
    glyph table key = ch
      where
        Just ch = lookup key table

-- | The one-character identifier of every 'Rank', from 'Two' up to 'Ace'.
rankIdentifiers :: [(Rank, Char)]
rankIdentifiers = zip [minBound .. maxBound] "23456789TJQKA"

-- | The one-character (lower-case) identifier of every 'Suit', in 'Enum' order.
suitIdentifiers :: [(Suit, Char)]
suitIdentifiers = zip [minBound .. maxBound] "hdcs"

-- FFI binding to the C helper @hs_StdDeck_MASK@. In this module it is only
-- called with card indices 0..51 (to fill 'cardSetVector'); presumably it
-- returns the mask containing just that card — confirm against the C side.
foreign import ccall unsafe "hs_StdDeck_MASK" c_getMASK :: CInt -> StdDeck_CardMask

-- | O(1). Create a singleton set.
--
--   NOTE(review): the mask is fetched with 'VU.unsafeIndex', so a 'Card' built
--   directly with an index outside 0..51 is not bounds-checked here.
singleton :: Card -> CardSet
singleton (Card idx) = CardSet (fromIntegral mask)
  where
    mask = cardSetVector `VU.unsafeIndex` fromIntegral idx

-- Single-card masks for card indices 0..51, built from c_getMASK. Indexing
-- this vector is faster than calling c_getMASK each time.
cardSetVector :: VU.Vector Int64
cardSetVector = VU.generate 52 (fromIntegral . c_getMASK . fromIntegral)







-- FFI binding to the C helper @hs_StdDeck_numCards@; 'size' uses it to count
-- the cards in a mask.
foreign import ccall unsafe "hs_StdDeck_numCards" c_numCards :: StdDeck_CardMask -> CInt
-- TODO: Is this O(n) or O(1) ?
-- | O(n). The number of cards in the set.
--
--   The count is delegated to the C helper behind 'c_numCards', whose cost is
--   not visible from Haskell (hence the TODO above).
--
--   Performance note: Try to avoid using this function in an inner loop.
size :: CardSet -> Int
size = fromIntegral . c_numCards . unmask

-- Lift a binary operation on raw masks to a binary operation on 'CardSet's.
maskOP :: (StdDeck_CardMask -> StdDeck_CardMask -> StdDeck_CardMask) -> CardSet -> CardSet -> CardSet
maskOP op lhs rhs = CardSet (unmask lhs `op` unmask rhs)

-- Lift a unary operation on raw masks to a unary operation on 'CardSet's.
maskUnOP :: (StdDeck_CardMask -> StdDeck_CardMask) -> CardSet -> CardSet
maskUnOP unop = CardSet . unop . unmask

-- | O(1). The union of two sets: bitwise OR of their masks.
union :: CardSet -> CardSet -> CardSet
union (CardSet a) (CardSet b) = CardSet (a .|. b)

-- | O(1). The intersection of two sets: bitwise AND of their masks.
intersection :: CardSet -> CardSet -> CardSet
intersection (CardSet a) (CardSet b) = CardSet (a .&. b)

-- | O(1). Find the inverse set such that @set `intersection` inverse set = empty@ and
--   @set `union` inverse set = fromList [minBound ..]@.
--
--   Only the card bits are flipped. A plain 'complement' would also set the
--   unused bits of the mask, so the result would not compare equal to
--   'fromList'-built sets and 'isEmpty' would reject the inverse of a full deck.
inverse :: CardSet -> CardSet
inverse = maskUnOP (\m -> complement m .&. allCardsMask)
  where
    -- 0x1FFF1FFF1FFF1FFF == 2305596714850918399, the same all-cards mask the
    -- 'Random' 'CardSet' instance applies to keep only card bits.
    allCardsMask :: StdDeck_CardMask
    allCardsMask = 0x1FFF1FFF1FFF1FFF

-- | O(1). The empty set: a mask with no bits set.
empty :: CardSet
empty = CardSet { unmask = 0 }

-- | O(n). Create a set from a list of cards. Duplicates are harmless.
fromList :: [Card] -> CardSet
fromList cards = foldl' (\acc card -> acc `union` singleton card) empty cards

-- | O(n). Convert the set to a list of cards, in 'allCards' order
--   (descending rank, then ascending suit).
toList :: CardSet -> [Card]
toList set = [ card | card <- allCards, set `member` card ]

-- Every card, ordered by descending rank and, within a rank, by ascending suit.
allCards :: [Card]
allCards =
    [ mkCard rank suit
    | rank <- reverse [minBound .. maxBound]
    , suit <- [minBound .. maxBound]
    ]

-- | O(n). Count the number of cards with a specific 'Rank' in a set.
countRank :: CardSet -> Rank -> Int
countRank mask rank =
    size (mask `intersection` (rankMaskVector V.! fromEnum rank))

-- One set per 'Rank' (indexed by 'fromEnum'), containing that rank in every
-- 'Suit'. Kept at top level so the table is built once, not on every call.
rankMaskVector :: V.Vector CardSet
rankMaskVector =
    V.fromList [ fromList (map (mkCard eachRank) [minBound .. ]) | eachRank <- [minBound .. ] ]

-- | O(n). Count the number of cards with a specific 'Suit' in a set.
countSuit :: CardSet -> Suit -> Int
countSuit mask suit =
    size (mask `intersection` (suitMaskVector V.! fromEnum suit))

-- One set per 'Suit' (indexed by 'fromEnum'), containing every 'Rank' of that
-- suit. Kept at top level so the table is built once, not on every call.
suitMaskVector :: V.Vector CardSet
suitMaskVector =
    V.fromList [ fromList (map (flip mkCard eachSuit) [minBound .. ]) | eachSuit <- [minBound .. ] ]

-- | O(1). Is the card in the set? True when the set shares a bit with the
--   card's singleton mask.
member :: CardSet -> Card -> Bool
member mask card = not (isEmpty overlap)
  where
    overlap = intersection mask (singleton card)

-- | O(1). Is this the empty set? True exactly when the raw mask is zero.
isEmpty :: CardSet -> Bool
isEmpty set = unmask set == 0