{-# OPTIONS -fno-warn-tabs #-}
module Majority.Rank where
import Data.Bool
import Data.Eq (Eq(..))
import Data.Foldable (Foldable(..))
import Data.Function (($))
import Data.Functor ((<$>))
import Data.Ord (Ord(..))
import Data.Ratio
import Data.Semigroup (Semigroup(..))
import Prelude (Integer, Integral(..), Num(..), RealFrac(..), error, undefined)

import Majority.Merit hiding (merit)
import Majority.Value

-- | How many judges take part.
type JS = Integer

-- | How many grades are available.
type GS = Integer

-- | Index of a grade, counted from 0 (up to @gs-1@).
type G = Integer

-- | 'rankOfMajorityValue gs mv' returns
-- the number of 'MajorityValue' lower than given 'mv'.
--
-- 'gs' is the number of grades. 'mv' is first normalized with
-- 'normalizeMajorityValue'; its 'Middle's are then consumed step by step,
-- each step adding a 'countMiddleFrom' contribution for the current
-- (remaining) count 'n' and the current 'Middle''s low/high ranks.
rankOfMajorityValue :: GS -> MajorityValue (Ranked grade) -> Integer
rankOfMajorityValue gs mv =
        -- 'n' starts at twice the total share of the normalized value,
        -- since each step below removes 'dn' from 'n' but only dn/2 from a share.
        go ((2 *) $ sum $ middleShare <$> mvN) (0,0) mvN
        where
        MajorityValue mvN = normalizeMajorityValue mv
        -- go n (l0,h0) middles
        --   n       : remaining count passed (via 'numerator') to 'countMiddleFrom';
        --             presumably the number of judges left to place, and integral
        --             after normalization — TODO confirm.
        --   (l0,h0) : lower bound pair given to 'countMiddleFrom'; (0,0) at first,
        --             then (0, rank high) of the 'Middle' being consumed.
        --   middles : the 'Middle's still to consume.
        go :: Rational -> (G,G) -> [Middle (Ranked grade)] -> Integer
        go _n _0 [] = 0
        go n (l0,h0) (Middle s low high : ms)
         -- A 'Middle' whose share is used up is skipped.
         | s <= 0 = go n (l0,h0) ms
         | otherwise =
                countMiddleFrom (numerator $ n) gs (l0,h0) (rank low, rank high) +
                go (n - dn) (0, rank high) (Middle (s - dn * (1%2)) low high : ms)
                -- An integral share is reduced by 1 (and 'n' by 2);
                -- a non-integral share is reduced by 1/2 (and 'n' by 1).
                where dn = if denominator s == 1 then 2 else 1

-- | 'positionOfMajorityValue gs mv' is 'rankOfMajorityValue gs mv'
-- divided by 'countMerits' for twice the numerator of the total
-- 'middleShare' of 'mv', over 'gs' grades.
--
-- NOTE(review): the judge count here comes from the shares of 'mv' as given,
-- whereas 'rankOfMajorityValue' works on 'normalizeMajorityValue mv';
-- confirm both always agree on the number of judges.
positionOfMajorityValue :: GS -> MajorityValue (Ranked grade) -> Rational
positionOfMajorityValue gs mv = lowerCount % total
        where
        lowerCount = rankOfMajorityValue gs mv
        shares = sum (middleShare <$> unMajorityValue mv)
        total = countMerits (2 * numerator shares) gs

-- | 'countMiddleFrom js gs (l0,h0) (l1,h1)' sums 'countMiddle js gs'
-- over a set of (lower, upper) majority-grade pairs delimited by
-- '(l0,h0)' and '(l1,h1)':
--
-- * when 'js' is odd: the diagonal pairs @(l,l)@ with @l0 <= l < l1@
--   ('h0' and 'h1' are not used);
-- * when 'js' is even: row 'l0' with @h0 <= h <= gs-1@
--   (or only @h < h1@ when @l0 >= l1@), every row 'l' strictly between
--   'l0' and 'l1' with @max l h0 <= h <= gs-1@, and, when @l0 < l1@,
--   row 'l1' with @max l1 h0 <= h < h1@.
countMiddleFrom :: JS -> GS -> (G,G) -> (G,G) -> Integer
countMiddleFrom js gs (l0,h0) (l1,h1)
        | js`mod`2 == 0 = total evenPairs
        | otherwise     = total oddPairs
        where
        total pairs = sum (countMiddle js gs <$> pairs)
        -- Odd number of judges: a single middle grade, so only the diagonal.
        oddPairs = [ (l,l) | l <- [l0 .. l1-1] ]
        -- Even number of judges: distinct lower/upper middles, row by row.
        evenPairs = firstRow <> innerRows <> lastRow
        firstRow =
         [ (l0,h) | h <- [h0 .. (if l0 < l1 then gs-1 else h1-1)] ]
        innerRows =
         [ (l,h) | l <- [l0+1 .. l1-1], h <- [max l h0 .. gs-1] ]
        lastRow
         | l0 < l1   = [ (l1,h) | h <- [max l1 h0 .. h1-1] ]
         | otherwise = []

-- | 'countMiddle js gs (l,h)' counts the 'MajorityValue's for 'js' judges
-- over 'gs' grades whose lower and upper majority grades are 'l' and 'h'.
--
-- The @(js-1) `div` 2@ judges on each side of the middle are counted
-- independently: those below may use any of the @l+1@ grades @0..l@
-- (grades start at 0), those above any of the @gs-h@ grades @h..gs-1@.
-- The two 'countMerits' are multiplied together.
countMiddle :: JS -> GS -> (G,G) -> Integer
countMiddle js gs (l,h) = below * above
        where
        half  = (js - 1) `div` 2
        below = countMerits half (l + 1)
        above = countMerits half (gs - h)

-- | 'probaMajorityGrades js gs' computes, for each grade @0..gs-1@ in order,
-- the fraction of all 'countMerits js gs' merits whose lower and upper
-- majority grades are both that grade: the probability of that grade being
-- the 'MajorityGrade', assuming every merit is equally likely.
--
-- NOTE(review): for an even 'js' only merits whose two middle grades
-- coincide are counted; confirm that is the intended notion.
probaMajorityGrades :: JS -> GS -> [Rational]
probaMajorityGrades js gs = shareOf <$> [0 .. gs-1]
        where
        total = countMerits js gs
        shareOf g = countMiddle js gs (g,g) % total

-- | 'countMerits js gs' returns how many 'Merit's of size 'js'
-- can be built from 'gs' grades.
-- That is the number of ways to split 'js' into exactly 'gs' ordered parts,
-- each part between 0 and 'js' ("stars and bars"),
-- i.e. the binomial coefficient @'nCk' (js+gs-1) (gs-1)@.
countMerits :: JS -> GS -> Integer
countMerits js gs = nCk (js + gs - 1) (gs - 1)

-- | 'lastRank js gs' is one less than 'countMerits js gs':
-- the highest 0-based index among all the 'Merit's of size 'js' over 'gs' grades.
lastRank :: JS -> GS -> Integer
lastRank js gs = total - 1
        where total = countMerits js gs

-- | @'nCk' n k@ returns the number of combinations of size 'k' from a set of size 'n'.
--
-- Returns @0@ when @k < 0@ or @n < k@ (no such subset exists),
-- and fails with an explicit error message when @n < 0@.
--
-- Computed using the recurrence:
-- @'nCk' n k == 'nCk' n (k-1) * (n-k+1) / k@
-- multiplying before dividing so that every intermediate division is exact.
nCk :: Integral i => i -> i -> i
n`nCk`k | n<0        = error "Majority.Rank.nCk: negative set size"
        | k<0 || n<k = 0
        | otherwise  = go 1 1
        where
        -- 'acc' holds @'nCk' n (i-1)@; stop once 'i' exceeds 'k''.
        go i acc = if k' < i then acc else go (i+1) (acc * (n-i+1) `div` i)
        -- Use the symmetry @'nCk' n k == 'nCk' n (n-k)@ to iterate over the
        -- smaller of 'k' and @n-k@: fewer steps and smaller intermediate products.
        k' = if n`div`2 < k then n-k else k
infix 7 `nCk`