{-# LANGUAGE FunctionalDependencies #-}
{-# LANGUAGE MultiParamTypeClasses #-}

-- |GA, a Haskell library for working with genetic algorithms
--
-- Aug. 2011, by Kenneth Hoste
--
-- version: 0.1
module GA (Entity(..), 
           GAConfig(..), 
           ShowEntity(..), 
           evolve) where

import Data.List (intersperse, sortBy, nub)
import Data.Maybe (fromJust, isJust)
import Data.Ord (comparing)
import Debug.Trace (trace)
import System.Directory (createDirectoryIfMissing, doesFileExist)
import System.Random (StdGen, mkStdGen, randoms)

-- DEBUGGING

-- |Enable/disable debugging output (hard coded).
--
-- When 'True', 'dbg' routes its message through 'trace' and 'checkpointGen'
-- prints the name of every checkpoint file it writes. Changing the value
-- requires recompiling the module.
debug :: Bool
debug = False

-- |Pass a value through unchanged, emitting the given message via 'trace'
-- first when 'debug' is switched on.
dbg :: String -> a -> a
dbg msg val
  | debug     = trace msg val
  | otherwise = val

-- |Group consecutive elements of a list into pairs.
--
-- The list must hold an even number of elements; a dangling last element
-- results in a call to 'error' (raised lazily, once that part of the result
-- is demanded).
currify :: [a] -> [(a,a)]
currify items = case items of
  []                  -> []
  [_]                 -> error "(currify) ERROR: only one element left?!?"
  (left:right:others) -> (left,right) : currify others

-- |Take and drop elements of a list in a single pass.
--
-- Returns the first @n@ elements together with the remainder of the list.
-- A non-positive @n@ yields an empty prefix and the untouched list. If the
-- list holds fewer than @n@ elements, the whole list is returned as the
-- prefix and the remainder is empty (instead of crashing on 'head'/'tail'
-- of an empty list).
takeAndDrop :: Int -> [a] -> ([a],[a])
takeAndDrop n xs = splitAt n xs


-- |Settings that control a run of the genetic algorithm.
data GAConfig = GAConfig
  { popSize :: Int
    -- ^ population size
  , archiveSize :: Int
    -- ^ size of archive (best entities so far)
  , maxGenerations :: Int
    -- ^ maximum number of generations to evolve
  , crossoverRate :: Float
    -- ^ fraction of entities generated by crossover (tip: >= 0.80)
  , mutationRate :: Float
    -- ^ fraction of entities generated by mutation (tip: <= 0.20)
  , crossoverParam :: Float
    -- ^ parameter for crossover (semantics depend on actual crossover operator)
  , mutationParam :: Float
    -- ^ parameter for mutation (semantics depend on actual mutation operator)
  , withCheckpointing :: Bool
    -- ^ enable/disable built-in checkpointing mechanism
  }

-- |Type class for entities that represent a candidate solution.
--
-- Three parameters:
--
-- * data structure representing an entity (a)
--
-- * data used to score an entity, e.g. a list of numbers (b)
--
-- * some kind of pool used to generate random entities, e.g. a Hoogle database (c)
--
-- The functional dependencies state that the entity type determines both
-- the scoring data type and the pool type. The 'Read' and 'Show'
-- superclasses are what the checkpointing code uses to store and reload
-- generations.
class (Eq a, Read a, Show a, ShowEntity a) => Entity a b c | a -> b, a -> c where
  -- |Generate a random entity.
  --
  -- Arguments: the pool and an 'Int' (the library passes a random seed).
  genRandom :: c -> Int -> a
  -- |Crossover operator: combine two entities into a new entity.
  --
  -- Arguments: the pool, the crossover parameter, an 'Int' (the library
  -- passes a random seed) and the two parent entities. A result of
  -- 'Nothing' is discarded by the evolution step.
  crossover :: c -> Float -> Int -> a -> a -> Maybe a
  -- |Mutation operator: mutate an entity into a new entity.
  --
  -- Arguments: the pool, the mutation parameter, an 'Int' (the library
  -- passes a random seed) and the entity to mutate. A result of 'Nothing'
  -- is discarded by the evolution step.
  mutation :: c -> Float -> Int -> a -> Maybe a
  -- |Score an entity (lower is better).
  --
  -- A score of exactly 0.0 is treated as a perfect entity and stops the
  -- evolution early.
  score :: a -> b -> Double

-- |A possibly scored entity.
--
-- The first component is 'Nothing' until the entity has been scored and
-- @Just fitness@ afterwards.
type ScoredEntity a = (Maybe Double, a)

-- |Scored generation (population and archive).
--
-- The first list is the current population, the second one the archive of
-- best entities found so far.
type ScoredGen a = ([ScoredEntity a],[ScoredEntity a])

-- |Type class for pretty printing an entity instead of just using the default show implementation.
--
-- It is a superclass of 'Entity' and is used when rendering scored entities
-- in the debugging output.
class ShowEntity a where
  -- |Show an entity.
  showEntity :: a -> String

-- |Render a scored entity as @(score, entity)@, using 'showEntity' for the
-- entity part.
showScoredEntity :: ShowEntity a => ScoredEntity a -> String
showScoredEntity (fitness,ent) = concat ["(", show fitness, ", ", showEntity ent, ")"]

-- |Render a list of scored entities as a comma-separated list enclosed in
-- square brackets.
showScoredEntities :: ShowEntity a => [ScoredEntity a] -> String
showScoredEntities es = "[" ++ concat (intersperse "," rendered) ++ "]"
  where
    rendered = map showScoredEntity es

-- |Initialize: generate the initial population.
--
-- Consumes the first @n@ seeds to generate @n@ random entities from the
-- pool, and returns the unused seeds together with those entities.
initPop :: (Entity a b c) => c -> Int -> [Int] -> ([Int],[a])
initPop src n seeds = (unused, map (genRandom src) consumed)
  where
    (consumed,unused) = takeAndDrop n seeds

-- |Score an entity against the given data, unless it already carries a
-- score, in which case it is returned untouched.
scoreEnt :: (Entity a b c) => b -> ScoredEntity a -> ScoredEntity a
scoreEnt d entry@(known,ent) = case known of
  Just _  -> entry
  Nothing -> (Just (score ent d), ent)

-- |Binary tournament selection operator.
--
-- Picks two entities (possibly the same one twice) at pseudo-random
-- positions derived from the given seed and returns the one with the lower
-- score. On a tie the second pick is returned.
--
-- An unscored entity ('Nothing') never beats a scored one; the derived 'Ord'
-- instance of 'Maybe' would rank 'Nothing' below every 'Just' and thus let
-- unscored entities win every tournament.
--
-- Calls 'error' when the list of entities is empty.
tournamentSelection :: [ScoredEntity a] -> Int -> a
tournamentSelection [] _ = error "(tournamentSelection) ERROR: cannot select from an empty list of entities"
tournamentSelection xs seed = case map (xs !!) picks of
    [(s1,x1),(s2,x2)]
      | s1 `beats` s2 -> x1
      | otherwise     -> x2
    _ -> error "(tournamentSelection) ERROR: expected exactly two contestants"
  where
    len = length xs
    -- 'mod' with a positive divisor is never negative, so the indices are valid
    picks = take 2 $ map (`mod` len) $ randoms (mkStdGen seed)
    -- strictly better: lower score wins, unscored never wins
    beats (Just a) (Just b) = a < b
    beats (Just _) Nothing  = True
    beats Nothing  _        = False

-- |Function to perform a single evolution step:
--
-- * score all entities
--
-- * combine with best entities so far
--
-- * sort by fitness
--
-- * create new population using crossover/mutation
--
-- Arguments, in order: the pool, the data used for scoring, the number of
-- entities to create by crossover, by mutation and the archive size, the
-- crossover and mutation parameters, the current generation (population and
-- archive), and the generation index paired with the seed for this step.
--
-- Returns the new (unscored) population and the updated archive. The new
-- population holds at most @cn@ crossover children and at most @mn@
-- mutants; it can be smaller when too many operator applications fail.
evolutionStep :: (Entity a b c) => c -> b -> (Int,Int,Int) -> (Float,Float) -> ScoredGen a -> (Int,Int) -> ScoredGen a
evolutionStep src d (cn,mn,an) (crossPar,mutPar) (pop,archive) (gi,seed) = dbg debugMsg (pop',archive')
  where
    -- only evaluated when debugging is enabled
    debugMsg =    "\n\ngeneration " ++ (show gi) ++ ": \n\n"
               ++ "  scored population: " ++ (showScoredEntities scoredPop) ++ "\n\n"
               ++ "  archive: " ++ (showScoredEntities archive') ++ "\n\n"
               ++ "  archive fitnesses: " ++ (show $ map fst archive') ++ "\n\n"
               ++ "  generated " ++ show (length pop') ++ " entities\n\n"
               ++ (replicate 150 '=')
    -- score population
    scoredPop = map (scoreEnt d) pop
    -- combine with archive for selection
    combo = scoredPop ++ archive
    -- split seeds for crossover selection/seeds, mutation selection/seeds
    seeds = randoms (mkStdGen seed) :: [Int]
    -- generate twice as many crossover/mutation entities as needed, because crossover/mutation can fail
    (crossSelSeeds,seeds')   = takeAndDrop (2*2*cn) seeds
    (crossSeeds   ,seeds'')  = takeAndDrop (2*cn) seeds'
    (mutSelSeeds  ,seeds''') = takeAndDrop (2*mn) seeds''
    (mutSeeds     ,_)        = takeAndDrop (2*mn) seeds'''
    -- crossover entities: keep the first cn successful crossovers
    crossSel = currify $ map (tournamentSelection combo) crossSelSeeds
    crossEnts = take cn [ child | Just child <- zipWith (uncurry . crossover src crossPar) crossSeeds crossSel ]
    -- mutation entities: keep the first mn (not cn) successful mutations,
    -- so the new population does not exceed cn + mn entities
    mutSel = map (tournamentSelection combo) mutSelSeeds
    mutEnts = take mn [ mutant | Just mutant <- zipWith (mutation src mutPar) mutSeeds mutSel ]
    -- new population: crossovered + mutated entities
    pop' = zip (repeat Nothing) $ crossEnts ++ mutEnts
    -- new archive: best entities so far
    archive' = take an $ nub $ sortBy (comparing fst) $ filter (isJust . fst) combo

-- |Build the path of the checkpoint file for a generation.
--
-- The name encodes the population size, archive size, crossover/mutation
-- rates and crossover/mutation parameters from the configuration, followed
-- by the generation index and seed. The file lives in the @checkpoints@
-- directory.
chkptFileName :: GAConfig -> (Int,Int) -> FilePath
chkptFileName cfg (gi,seed) = dbg path path
  where
    settings = concat $ intersperse "-"
                 [ show (popSize cfg)
                 , show (archiveSize cfg)
                 , show (crossoverRate cfg)
                 , show (mutationRate cfg)
                 , show (crossoverParam cfg)
                 , show (mutationParam cfg)
                 ]
    path = concat ["checkpoints/GA-", settings, "-gen", show gi, "-seed-", show seed, ".chk"]

-- |Try to restore from checkpoint: first checkpoint for which a readable checkpoint file is found is restored.
--
-- The (generation index, seed) pairs are tried in the order given. For each
-- pair the matching checkpoint file is looked up; if it exists and its
-- contents parse completely as a scored generation, the generation index
-- and the restored generation are returned. A file that exists but cannot
-- be parsed (e.g. truncated or corrupted) is skipped and the next pair is
-- tried, rather than crashing later on a partial 'read'.
--
-- Returns 'Nothing' when no usable checkpoint is found.
restoreFromCheckpoint :: (Entity a b c) => GAConfig -> [(Int,Int)] -> IO (Maybe (Int,ScoredGen a))
restoreFromCheckpoint _ [] = return Nothing
restoreFromCheckpoint cfg ((gi,seed):genSeeds) = do
    chkptFound <- doesFileExist fn
    if chkptFound
      then do
        txt <- dbg ("chk for gen. " ++ (show gi) ++ " found") readFile fn
        -- same acceptance rule as 'read': exactly one parse, with nothing
        -- but whitespace left over
        case [ gen | (gen,rest) <- reads txt, ("","") <- lex rest ] of
          [gen] -> return $ Just (gi, gen)
          _     -> dbg ("chk for gen. " ++ (show gi) ++ " unreadable, skipped") $
                     restoreFromCheckpoint cfg genSeeds
      else restoreFromCheckpoint cfg genSeeds
  where
    fn = chkptFileName cfg (gi,seed)

-- |Write a checkpoint file for a single generation.
--
-- The generation (population and archive) is serialised with 'show' into
-- the file named by 'chkptFileName'; the @checkpoints@ directory is created
-- first if it does not exist yet. With debugging enabled, the target file
-- name is printed beforehand.
checkpointGen :: (Entity a b c) => GAConfig -> Int -> Int -> ScoredGen a -> IO()
checkpointGen cfg index seed gen@(_,_) = do
    announce
    createDirectoryIfMissing True "checkpoints"
    writeFile path (show gen)
  where
    path = chkptFileName cfg (index,seed)
    announce
      | debug     = putStrLn ("writing checkpoint for gen " ++ show index ++ " to " ++ path)
      | otherwise = return ()

-- |Evolution: evaluate generation, (maybe) checkpoint, continue.
--
-- For every (generation index, seed) pair the step function is applied to
-- the current generation, the result is checkpointed when checkpointing is
-- enabled in the configuration, and the best archived entity is reported on
-- standard output. Evolution stops early as soon as the best archived
-- entity has a fitness of exactly 0.0, and otherwise when the list of
-- generation indices/seeds is exhausted.
--
-- An archive without a scored head entity (e.g. an empty archive) is
-- reported and evolution simply continues, instead of crashing on 'head'.
--
-- Returns the last generation (population and archive) that was computed.
evolution :: (Entity a b c) => GAConfig -> ScoredGen a -> (ScoredGen a -> (Int,Int) -> ScoredGen a) -> [(Int,Int)] -> IO (ScoredGen a)
evolution cfg gen step ((gi,seed):gss) = do
    let newPa@(_,archive') = step gen (gi,seed)
    -- checkpoint generation if desired
    if withCheckpointing cfg
      then checkpointGen cfg gi seed newPa
      else return () -- skip checkpoint
    case archive' of
      (Just fitness, e) : _ -> do
        putStrLn $ "best entity (gen. " ++ show gi ++ "): " ++ (show e) ++ " [fitness: " ++ show fitness ++ "]"
        -- check for perfect entity
        if fitness == 0.0
          then do
            putStrLn $ "perfect entity found, finished after " ++ show gi ++ " generations!"
            return newPa
          else evolution cfg newPa step gss
      _ -> do
        -- nothing to report: the archive holds no scored entity
        putStrLn $ "no scored entity in archive (gen. " ++ show gi ++ ")"
        evolution cfg newPa step gss
-- no more gen. indices/seeds => quit
evolution _ gen _ [] = do
    putStrLn $ "done evolving!"
    return gen
 
-- |Do the evolution!
--
-- Runs the genetic algorithm with the given random generator,
-- configuration, pool and scoring data, and returns the entity at the head
-- of the final archive.
--
-- With checkpointing enabled, the most recent generation for which a
-- checkpoint file exists is restored and evolution resumes from the
-- generation after it; otherwise evolution starts from a freshly generated
-- population and an empty archive.
--
-- Calls 'error' when the final archive is empty.
evolve :: (Entity a b c) => StdGen -> GAConfig -> c -> b -> IO a
evolve g cfg src dataset = do
    restored <- if checkpointing
                   then restoreFromCheckpoint cfg (reverse genSeeds) :: (Entity a b c) => IO (Maybe (Int,ScoredGen a))
                   else return Nothing
    let (gi,(pop',archive')) = case restored of
            -- restored pop/archive from checkpoint
            Just chk -> dbg ("restored from checkpoint!\n\n") chk
            -- restore failed, new population and empty archive
            Nothing  -> dbg (if checkpointing then "no checkpoint found...\n\n"
                                              else "checkpoints ignored...\n\n")
                            (-1, (zip (repeat Nothing) pop, []))
    -- only run the generations that come after the restored one
    (_,resArchive) <- evolution cfg (pop',archive') step (filter ((>gi) . fst) genSeeds)
    case resArchive of
      []       -> error $ "(evolve) empty archive!"
      (best:_) -> return $ snd best
  where
    -- list of random integers
    rs = randoms g :: [Int]
    -- initial population, plus the seeds it did not consume
    (rs',pop) = initPop src ps rs
    ps = popSize cfg
    -- number of entities generated by crossover/mutation
    cCnt = round $ crossoverRate cfg * fromIntegral ps
    mCnt = round $ mutationRate cfg * fromIntegral ps
    -- one seed per generation, paired with the generation index
    genSeeds = zip [0..] $ take (maxGenerations cfg) rs'
    -- checkpoint?
    checkpointing = withCheckpointing cfg
    -- a single evolution step with everything but generation and seed fixed
    step = evolutionStep src dataset (cCnt,mCnt,archiveSize cfg) (crossoverParam cfg,mutationParam cfg)