module Majority.Rank where

import Data.Bool
import Data.Eq (Eq(..))
import Data.Foldable (Foldable(..))
import Data.Function (($))
import Data.Functor ((<$>))
import Data.Ord (Ord(..))
import Data.Ratio
import Data.Semigroup (Semigroup(..))
import Prelude (Integer, Integral(..), Num(..), RealFrac(..), error, undefined)
import Text.Show (Show(..))
import qualified Data.List as List

import Majority.Merit hiding (merit)
import Majority.Value

-- * Convenient type aliases
-- | Number of judges.
type JS = Integer
-- | Number of grades.
-- In this module grades are ranked from @0@ up to @gs-1@.
type GS = Integer
-- | Rank of a 'MajorityValue', as computed by 'rankOfMajorityValue':
-- the number of possible 'MajorityValue's lower than it.
type Rank = Integer

-- ** Type 'Median'
-- | A median, given as a pair of grades:
-- the first 'G' (lower median) is lower than or equal
-- to the second 'G' (higher median).
-- Use 'median' to build one with this invariant checked.
newtype Median = Median (G,G)
 deriving (Eq, Show)

-- | 'Median' constructor enforcing its invariant:
-- the lower grade @l@ must be lower than or equal to the higher grade @h@.
--
-- Calls 'error', mentioning both grades,
-- when the invariant does not hold.
median :: G -> G -> Median
median l h
 | l <= h    = Median (l,h)
 | otherwise = error $ "median: lower grade "<>show l<>" > higher grade "<>show h

-- * Ranking and unranking 'MajorityValue's

-- | @('rankOfMajorityValue' gs mv)@ returns
-- the number of possible 'MajorityValue's lower than the given 'mv',
-- using 'gs' grades.
--
-- The value is first normalized ('normalizeMajorityValue').
-- Then its 'Middle's are consumed from the first one onward.
-- Each step removes two judges from a 'Middle' whose 'Share' is whole,
-- or one judge otherwise.
-- It adds the number of 'MajorityValue's whose next 'Median'
-- comes before the current one ('countMediansBefore').
--
-- @
-- 'rankOfMajorityValue' gs . 'majorityValueOfRank' js gs
--  '<$>' [0..'lastRank' js gs] == [0..'lastRank' js gs]
-- @
rankOfMajorityValue :: GS -> MajorityValue (Ranked grade) -> Rank
rankOfMajorityValue gs mv = walk judges 0 middles
        where
        MajorityValue middles = normalizeMajorityValue mv
        -- Twice the total 'Share': the number of judges still to be ranked.
        judges = 2 * sum (middleShare <$> middles)
        walk :: Rational -> G -> [Middle (Ranked grade)] -> Rank
        walk _ _ [] = 0
        walk remaining minHigh (Middle share lo hi : rest)
         -- An empty (or exhausted) 'Middle' contributes nothing.
         | share <= 0 = walk remaining minHigh rest
         | otherwise =
                -- A whole 'Share' loses a pair of judges, otherwise a lone judge.
                let step   = if denominator share == 1 then 2 else 1
                    here   = Median (rank lo, rank hi)
                    before = countMediansBefore (numerator remaining) gs minHigh here
                    shrunk = Middle (share - step * (1%2)) lo hi
                in before + walk (remaining - step) (rank hi) (shrunk : rest)

-- | The inverse of 'rankOfMajorityValue':
-- @('majorityValueOfRank' js gs rk)@ returns the 'MajorityValue'
-- of rank 'rk' among those of 'js' judges using 'gs' grades.
--
-- Calls 'error' when 'rk' is not within @[0 .. 'lastRank' js gs]@.
--
-- @
-- 'majorityValueOfRank' js gs . 'rankOfMajorityValue' gs == 'id'
-- @
majorityValueOfRank :: JS -> GS -> Rank -> MajorityValue (Ranked ())
majorityValueOfRank js0 gs rk
 | not (0<=rk && rk<=lastRank js0 gs) =
        error $ "majorityValueOfRank: rank="<>show rk<>" but lastRank "<>show js0<>" "<>show gs<>"="<>show (lastRank js0 gs)
 | otherwise = MajorityValue $ go 0 js0 rk
        where
        -- Unrank one 'Median': a lone judge when 'js' is odd,
        -- a pair of judges otherwise.
        -- Then recurse on the remaining judges,
        -- passing the chosen high grade as the next 'previousHigh'.
        go previousHigh js r
         | js <= 0 = []
         | otherwise =
                -- Candidate 'Median's in the order 'listMediansBefore' yields them.
                let ms   = listMediansBefore js gs previousHigh (Median (gs,gs)) in
                -- Running totals of 'Merit's counted per candidate,
                -- kept while they do not exceed the rank 'r'.
                let skip = List.takeWhile (<= r) $ List.scanl1 (+) $ countMedian js gs <$> ms in
                -- Number of 'Merit's skipped before reaching the selected 'Median'.
                let dr   = if null skip then 0 else List.last skip in
                let dj   = if js`mod`2 == 0 then 2 else 1 in
                case List.drop (length skip) ms of
                 [] -> error $ "majorityValueOfRank: no Median left for rank="<>show r<>" with js="<>show js
                 Median (l,h) : _ ->
                        case go h (js - dj) (r - dr) of
                         -- Merge the 'Middle's which have the same 'Median' grades,
                         -- by adding their 'Share'.
                         Middle s rl1@(Ranked (l1, ())) rh1@(Ranked (h1, ())) : mv
                          | l1 == l && h1 == h -> Middle (dj%2 + s) rl1 rh1 : mv
                         mv -> Middle (dj%2) (Ranked (l,())) (Ranked (h,())) : mv

-- | @('positionOfMajorityValue' gs mv)@ returns
-- @'rankOfMajorityValue' gs mv@ divided by the number of possible 'Merit's
-- ('countMerits') for the same number of judges and 'gs' grades.
-- Presumably within @[0,1)@ for a well-formed 'mv', since ranks range over
-- @[0 .. 'lastRank' js gs]@. TODO confirm for fractional 'Share's.
positionOfMajorityValue :: GS -> MajorityValue (Ranked grade) -> Rational
positionOfMajorityValue gs mv =
        rankOfMajorityValue gs mv %
        countMerits (numerator js) gs
        where
        -- Count the judges on the normalized value, exactly as
        -- 'rankOfMajorityValue' does, so the rank and the total
        -- are computed for the same number of judges.
        js = (2 *) $ sum $ middleShare <$> unMajorityValue (normalizeMajorityValue mv)


-- ** Counting 'Merit's

-- | @('countMerits' js gs)@
-- returns the number of possible 'Merit's of 'js' judges using 'gs' grades.
-- That is the number of ways to divide a segment of length 'js'
-- into 'gs' consecutive segments (one per grade)
-- whose sizes are between '0' and 'js'.
--
-- The formula is: @(js+gs-1)·(js+gs-2)·…·(js+1) / (gs-1)·(gs-2)·…·2·1@
-- which is: @(js+gs-1)`nCk`(gs-1)@
countMerits :: JS -> GS -> Integer
countMerits js gs = nCk slots bars
        where
        -- One separator between each pair of adjacent grades.
        bars  = gs - 1
        -- Judges and separators together.
        slots = js + bars

-- | @('lastRank' js gs)@ returns the greatest rank:
-- that of the 'MajorityValue' where each of the 'js' judges
-- gives the highest of the 'gs' grades.
--
-- @'lastRank' js gs == 'countMerits' js gs - 1@.
lastRank :: JS -> GS -> Rank
lastRank js gs = total - 1
        where total = countMerits js gs

-- ** Counting 'Median's

-- | @('countMedian' js gs ('Median' (l,h)))@
-- returns the number of possible 'Merit's of 'js' judges using 'gs' grades
-- whose lower and higher median grades are @(l,h)@.
--
-- Besides the median judgment(s), the other judgments form:
--
-- * a lower side using only the grades @[0..l]@, and
-- * an upper side using only the grades @[h..gs-1]@.
--
-- Each side holds @(js-1)`div`2@ judgments.
-- The two sides are counted independently with 'countMerits' and multiplied.
countMedian :: JS -> GS -> Median -> Integer
countMedian js gs (Median (l,h)) = below * above
        where
        side  = (js - 1) `div` 2
        below = countMerits side (l + 1) -- grades 0..l: l+1 of them
        above = countMerits side (gs - h) -- grades h..gs-1

-- | @('countMediansBefore' js gs previousHigh ('Median' (low,high)))@
-- returns the number of possible 'Merit's with 'js' judges and 'gs' grades
-- whose 'Median' is one of those listed by 'listMediansBefore'.
-- Such a 'Median' @(l,h)@ satisfies @(l,h) < (low,high)@;
-- when 'js' is even it also satisfies @previousHigh <= h@.
countMediansBefore :: JS -> GS -> G -> Median -> Integer
countMediansBefore js gs previousHigh lh =
        List.foldl' (\acc m -> acc + countMedian js gs m) 0 $
        listMediansBefore js gs previousHigh lh

-- | @('listMediansBefore' js gs previousHigh ('Median' (l1,h1)))@
-- returns, in increasing lexicographic order, the 'Median's of possible
-- 'Merit's with 'js' judges and 'gs' grades
-- which are strictly lower than @(l1,h1)@.
--
-- When 'js' is even, only the 'Median's @(l,h)@ with @previousHigh <= h@
-- are listed.
-- When 'js' is odd, only single-grade 'Median's @(l,l)@ are listed,
-- and 'previousHigh' is not used.
-- NOTE(review): presumably because an odd number of judges only occurs
-- for the first (central) 'Median', before any high has been removed — confirm.
--
-- The order matters: 'majorityValueOfRank' indexes into this list
-- using running totals of 'countMedian'.
listMediansBefore :: JS -> GS -> G -> Median -> [Median]
listMediansBefore js gs previousHigh (Median (l1,h1))
 | js`mod`2 == 0 = evenBegin<>even<>evenEnd
 | otherwise = odd
        where
        l0 = 0
        -- Single-grade medians (l,l), for l from the low initial 'l0'
        -- up to (l1-1), just below the low target 'l1'.
        odd = [ Median (l,l) | l<-[l0..l1-1] ]
        -- With the low grade fixed at 'l0', walk the high grade
        -- from 'previousHigh' up to:
        -- - the highest grade (gs-1) if 'l0' is not the low target 'l1',
        -- - or (h1-1), just below the high target, otherwise.
        evenBegin =
         [ Median (l,h)
         | l<-[l0]
         , h<-[{-l`max`-}previousHigh..(if l0<l1 then gs-1 else h1-1)]
          -- NOTE: (max l) is useless here since 'l' equals 'l0',
          -- which is always lower than or equal to 'previousHigh'.
         ]
        -- Walk the low grade from the grade after the low initial (l0+1)
        -- up to the grade before the low target (l1-1),
        -- while the high grade 'h' walks
        -- from the max of the minimal high and the current low
        -- up to the highest grade (gs-1).
        -- Beware that when recursing by removing a Middle,
        -- the minimal high is not the low initial,
        -- but the high of the last removed Middle ('previousHigh').
        even =
         [ Median (l,h)
         | l<-[l0+1..l1-1]
         , h<-[l`max`previousHigh..gs-1]
         ]
        -- With the low grade at the low target 'l1'
        -- (only when l1 differs from 'l0', which evenBegin already covered),
        -- walk the high grade up to (h1-1), just below the high target,
        -- instead of up to the highest grade.
        evenEnd =
         [ Median (l,h)
         | l<-[l1 | l0 < l1]
         , h<-[l`max`previousHigh..h1-1]
         ]

-- | @('probaMedian' js gs)@ computes, for each grade @l@ in @[0..gs-1]@,
-- the fraction of all possible 'Merit's of 'js' judges using 'gs' grades
-- whose 'Median' is exactly @(l,l)@.
-- This is intended as the probability of @l@ to be the 'MajorityGrade'
-- when every 'Merit' is equally likely.
--
-- NOTE(review): for an even 'js', 'Merit's whose 'Median' is @(l,h)@
-- with @l < h@ are counted for no grade, so the result does not sum to 1.
-- Confirm this matches the intended definition of 'MajorityGrade'.
probaMedian :: JS -> GS -> [Rational]
probaMedian js gs = (\l -> countMedian js gs (Median (l,l)) % total) <$> [0..gs-1]
        where total = countMerits js gs

-- ** Utils
-- | @('nCk' n k)@ returns the binomial coefficient of 'n' and 'k',
-- that is the number of combinations of size 'k' from a set of size 'n'.
-- Requires @0 <= k <= n@, otherwise calls 'error'.
--
-- Computed using the recurrence:
-- @'nCk' n k == 'nCk' n (k-1) * (n-k+1) / k@
-- iterated up to the smaller of @k@ and @n-k@
-- (by symmetry: @'nCk' n k == 'nCk' n (n-k)@).
-- Each intermediate product is divisible by its divisor,
-- so the integral 'div' is exact.
nCk :: Integral i => i -> i -> i
n`nCk`k
 | n<0 || k<0 || n<k = error "nCk: requires 0 <= k <= n"
 | otherwise         = go 1 1
        where
        -- Use the symmetry to iterate over the smaller of k and n-k,
        -- which needs fewer steps and keeps intermediate numbers smaller.
        k' = if n`div`2 < k then n-k else k
        -- Invariant: acc == nCk n (i-1).
        go i acc = if k' < i then acc else go (i+1) (acc * (n-i+1) `div` i)
infix 7 `nCk`