{-# LANGUAGE NamedFieldPuns, RecordWildCards #-}

module Control.Concurrent.Annealer.Population (Population, offerState, initPop, pickState, getBest) where

import Control.Concurrent
import Control.Monad.Random (MonadRandom (..))

import Data.List
import Data.Ord
import Data.Functor

-- | A population of candidate states of type @s@, each stored together with
-- its score of type @e@. States offered with 'offerState' are queued on
-- 'solsChan' and merged into 'solsVar' by the thread that 'initPop' forks.
data Population s e = Pop
    { popSize  :: {-# UNPACK #-} !Int
      -- ^ number of states the population is trimmed back to
    , solScore :: s -> e
      -- ^ scoring function; 'getBest' treats the lowest score as best
    , solsVar  :: MVar [(s, e)]
      -- ^ current states, each paired with its score
    , solsChan :: Chan s
      -- ^ states that have been offered but not yet merged
    }

-- | Queues a state for inclusion in the population. This only enqueues it;
-- the population's background thread later merges it in (or eliminates it).
offerState :: s -> Population s e -> IO ()
offerState s Pop{solsChan = queue} = writeChan queue s

-- | Background loop started by 'initPop'. It repeatedly takes an offered
-- state from 'solsChan' and merges it into 'solsVar'.
--
-- While the population holds fewer than 'popSize' states, the new state is
-- simply added. Otherwise the enlarged list is shuffled (so ties end up in
-- random order), sorted by descending score (index 0 holds the highest
-- score), and trimmed back to 'popSize' with 'elimListTo'. The trimming uses
-- the bias parameter @r = sqrt (log t + 1)@, where @t@ counts the trimming
-- rounds performed so far. The loop never returns.
processChan :: Ord e => Int -> Population s e -> IO ()
processChan t pop@Pop{..} = t `seq` do
    sol <- readChan solsChan
    let offered = (sol, solScore sol)
    -- modifyMVar puts the previous list back if the update action throws,
    -- so readers of solsVar ('pickState', 'getBest') are not left blocked
    -- on an empty MVar.
    trimmed <- modifyMVar solsVar $ \sols ->
        if length sols < popSize
            then return (offered : sols, False)
            else do
                let grown = offered : sols
                    -- Use the real length: 'initPop' may seed more than
                    -- 'popSize' states, and both 'shuffle' and 'elimListTo'
                    -- expect the exact length of the list they are given.
                    len = length grown
                    -- log 0 is -Infinity, which would make r NaN on the
                    -- first round; treat t = 0 like t = 1 (giving r = 1).
                    r = sqrt (log (fromIntegral (max 1 t)) + 1.0) :: Double
                sorted <- sortBy (flip $ comparing snd) <$> shuffle len grown
                kept <- elimListTo len popSize r sorted
                return (kept, True)
    processChan (if trimmed then t + 1 else t) pop

-- | Eliminates elements of the list until it reaches a certain size.
--
-- @elimListTo n m r xs@ expects @n@ to be the length of @xs@. It removes one
-- element at a time with 'elimList' (each removal is biased geometrically
-- in the element's position by @r@) until @m@ elements remain. If @xs@ is
-- already at or below the target size (@n <= m@), it is returned unchanged.
elimListTo :: MonadRandom m => Int -> Int -> Double -> [a] -> m [a]
elimListTo n m r xs
    -- '<=' rather than '==': with '==', a call where n < m would keep
    -- decrementing n and never terminate.
    | n <= m    = return xs
    | otherwise = elimList n r xs >>= elimListTo (n - 1) m r

-- | Returns a state chosen uniformly at random from the current population.
--
-- The index is drawn from the states actually present rather than from
-- 'popSize', so this is safe to call before the population has filled up
-- (e.g. when 'initPop' was seeded with fewer than 'popSize' states).
-- Throws an 'IOError' if the population is empty.
pickState :: Population s e -> IO s
pickState Pop{solsVar} = do
    sols <- readMVar solsVar
    case sols of
        [] -> ioError (userError "pickState: population is empty")
        _  -> do
            i <- getRandomR (0, length sols - 1)
            return (fst (sols !! i))

-- | Returns the state with the lowest score currently in the population.
-- Like 'minimumBy', it fails on an empty population.
getBest :: Ord e => Population s e -> IO s
getBest pop = fst . minimumBy (comparing snd) <$> readMVar (solsVar pop)

-- | Creates a population seeded with the given states, each paired with its
-- score, and forks the background thread ('processChan') that merges states
-- offered through 'offerState' into it.
initPop :: Ord e => (s -> e) -> [s] -> Int -> IO (Population s e)
initPop solScore sols popSize = do
    var <- newMVar (zip sols (map solScore sols))
    queue <- newChan
    let pop = Pop
            { popSize
            , solScore
            , solsVar = var
            , solsChan = queue
            }
    _ <- forkIO (processChan 0 pop)
    return pop

{-# SPECIALIZE elimList :: Int -> Double -> [a] -> IO [a] #-}
-- | Deletes one randomly chosen element from a list of length @n@.
--
-- Index @i@ (0-based) is removed with probability
-- @r ^ i * (r - 1) / (r ^ n - 1)@, i.e. proportional to @r ^ i@. This is a
-- geometric distribution truncated to @[0, n - 1]@, sampled by inverting
-- its CDF with a single uniform draw. For @r > 1@ later positions are more
-- likely to be removed; for @0 < r < 1@ earlier ones are.
--
-- When @r@ is 1 (where the distribution becomes uniform), not positive,
-- NaN, or large enough that @r ^ n@ overflows, the inversion is undefined
-- and a uniformly random index is used instead.
elimList :: MonadRandom m => Int -> Double -> [a] -> m [a]
elimList _ _ [] = return []
elimList n r xs = n `seq` do
    x <- getRandom
    let i = clamp (sample x)
        (before, after) = splitAt i xs
    return (before ++ drop 1 after)
  where
    -- The truncated CDF is (r^k - 1) / (r^n - 1). Using r ^ n (not
    -- r ^ (n + 1)) makes the inversion map x in [0, 1) onto [0, n), so every
    -- index gets exactly its own share of probability.
    q = r ^ max 0 n
    degenerate = not (r > 0) || r == 1 || isNaN q || isInfinite q
    sample x
        | degenerate = floor (x * fromIntegral n)
        | otherwise  = floor (logBase r (1 - x * (1 - q)))
    -- Guards against floating-point rounding at either end of the range.
    clamp = max 0 . min (n - 1)

{-# SPECIALIZE shuffle :: Int -> [a] -> IO [a] #-}
-- | Randomly permutes a list by repeatedly drawing an element at a random
-- index from the ones not yet chosen.
--
-- The first argument should equal the list's length, because it bounds each
-- index draw. If it does not, the result is not uniform, or the element
-- pattern can fail when the drawn index is past the end of the list.
shuffle :: MonadRandom m => Int -> [a] -> m [a]
shuffle n xs
    | n `seq` null xs = return []
    | otherwise = do
        k <- getRandomR (0, n - 1)
        let (front, picked : back) = splitAt k xs
        (picked :) <$> shuffle (n - 1) (front ++ back)

-- Compile-time rewrite rule: an expression whose left operand is literally
-- [] (i.e. [] ++ xs) is rewritten to xs. It only matches code where the
-- empty list appears syntactically at compile time, not lists that happen
-- to be empty at run time.
-- NOTE(review): (++) recurses on its left argument and may be inlined
-- before this rule gets a chance to match; confirm it actually fires
-- (e.g. with -ddump-rule-firings) or consider removing it.
{-# RULES
	"[] ++" forall xs . [] ++ xs = xs;
	#-}