{-# LANGUAGE FunctionalDependencies #-}
{-# LANGUAGE MultiParamTypeClasses #-}

-- |GA, a Haskell library for working with genetic algorithms
--
-- Aug. 2011 - Sept. 2011, by Kenneth Hoste
--
-- version: 0.2
module GA (Entity(..), 
           GAConfig(..), 
           evolve, 
           evolveVerbose,
           randomSearch) where

import Control.Monad (zipWithM)
import Control.Monad.IO.Class (MonadIO, liftIO)
import Data.List (sortBy, nub)
import Data.Maybe (catMaybes, fromJust, isJust)
import Data.Ord (comparing)
import System.Directory (createDirectoryIfMissing, doesFileExist)
import System.Random (StdGen, mkStdGen, random, randoms)

-- |Group the consecutive elements of a list into pairs.
--
-- The list is expected to hold an even number of elements; a dangling
-- last element results in an error once that part of the result is forced.
currify :: [a] -- ^ list
           -> [(a,a)] -- ^ list of tuples
currify elems = case elems of
    []         -> []
    [_]        -> error "(currify) ERROR: only one element left?!?"
    (a:b:rest) -> (a,b) : currify rest

-- |Take and drop elements of a list in a single pass.
--
-- Thin wrapper around 'splitAt': the first component holds (at most) the
-- first @n@ elements, the second component holds the remainder. Unlike the
-- previous hand-rolled version (built on 'head' and 'tail'), a list shorter
-- than @n@ no longer raises an error; whatever is available is returned.
-- A non-positive @n@ yields an empty first component.
takeAndDrop :: Int -- ^ number of elements to take/drop
            -> [a] -- ^ list 
            -> ([a],[a]) -- ^ result: taken list elements and rest of list
takeAndDrop n xs = splitAt n xs

-- |Settings that control a run of the genetic algorithm.
data GAConfig = GAConfig
    { getPopSize :: Int
      -- ^ population size
    , getArchiveSize :: Int
      -- ^ size of archive (best entities so far)
    , getMaxGenerations :: Int
      -- ^ maximum number of generations to evolve
    , getCrossoverRate :: Float
      -- ^ fraction of entities generated by crossover (tip: >= 0.80)
    , getMutationRate :: Float
      -- ^ fraction of entities generated by mutation (tip: <= 0.20)
    , getCrossoverParam :: Float
      -- ^ parameter for crossover (semantics depend on crossover operator)
    , getMutationParam :: Float
      -- ^ parameter for mutation (semantics depend on mutation operator)
    , getWithCheckpointing :: Bool
      -- ^ enable/disable built-in checkpointing mechanism
    , getRescoreArchive :: Bool
      -- ^ rescore archive in each generation?
    }

-- |Type class for entities that represent a candidate solution.
--
-- The class takes five parameters:
--
-- * the data structure representing an entity (e)
--
-- * the score type (s), e.g. Double
--
-- * the data used to score an entity, e.g. a list of numbers (d)
--
-- * some kind of pool used to generate random entities, 
--   e.g. a Hoogle database (p)
--
-- * the monad to operate in (m)
--
-- The entity type determines all other parameters (see the functional
-- dependencies).
--
-- A minimal implementation provides genRandom, crossover, mutation, 
-- and one of score', score or scorePop.
--
class (Eq e, Read e, Show e, 
       Ord s, Read s, Show s, 
       Monad m)
   => Entity e s d p m 
    | e -> s, e -> d, e -> p, e -> m where

  -- |Generate a random entity. [required]
  genRandom :: p -- ^ pool for generating random entities
            -> Int -- ^ random seed
            -> m e -- ^ random entity

  -- |Crossover operator: combine two entities into a new entity. [required]
  crossover :: p -- ^ entity pool
            -> Float -- ^ crossover parameter
            -> Int -- ^ random seed
            -> e -- ^ first entity
            -> e -- ^ second entity
            -> m (Maybe e) -- ^ entity resulting from crossover (if any)

  -- |Mutation operator: mutate an entity into a new entity. [required]
  mutation :: p -- ^ entity pool
           -> Float -- ^ mutation parameter
           -> Int -- ^ random seed
           -> e -- ^ entity to mutate
           -> m (Maybe e) -- ^ mutated entity (if any)

  -- |Score an entity (lower is better), pure version. [optional]
  --
  -- Not used when score or scorePop are implemented; the default
  -- implementation raises an error.
  score' :: d -- ^ dataset for scoring entities
         -> e -- ^ entity to score
         -> (Maybe s) -- ^ entity score
  score' _ _ = error ("(GA) score' is not defined, "
                      ++ "nor is score or scorePop!")

  -- |Score an entity (lower is better), monadic version. [optional]
  --
  -- The default implementation lifts score' into the monad; 
  -- it is not used when scorePop is implemented.
  score :: d -- ^ dataset for scoring entities
        -> e -- ^ entity to score
        -> m (Maybe s) -- ^ entity score
  score d e = return (score' d e)

  -- |Score an entire population of entities. [optional]
  --
  -- The default implementation returns Nothing, 
  -- which triggers individual scoring of entities.
  scorePop :: d -- ^ dataset to score entities
           -> [e] -- ^ universe of known entities
           -> [e] -- ^ population of entities to score
           -> m (Maybe [Maybe s]) -- ^ scores for population entities
  scorePop _ _ _ = return Nothing

  -- |Determines whether a score indicates a perfect entity. [optional]
  --
  -- The default implementation always returns False.
  isPerfect :: (e,s) -- ^ scored entity
               -> Bool -- ^ whether or not scored entity is perfect
  isPerfect = const False


-- |A possibly scored entity: the score (if one is available) paired
-- with the entity itself.
type ScoredEntity e s = (Maybe s, e)

-- |Scored generation: the current population, together with the archive
-- of scored entities.
type Generation e s = ([e],[ScoredEntity e s])

-- |Universe of entities, i.e. the list of known entities.
type Universe e = [e]

-- |Initialize: generate the initial population.
--
-- The given seed is expanded into one seed per entity, and each of those
-- seeds is handed to 'genRandom'.
initPop :: (Entity e s d p m) => p -- ^ pool for generating random entities
                            -> Int -- ^ population size
                            -> Int -- ^ random seed
                            -> m [e] -- ^ initialized population
initPop pool n seed = mapM (genRandom pool) entitySeeds
  where
    entitySeeds = take n (randoms (mkStdGen seed))

-- |Binary tournament selection operator.
--
-- Two entities are drawn at random (with replacement), and the one with
-- the lower score is selected. An entity without a score ('Nothing')
-- always loses against a scored entity; previously the derived ordering
-- on 'Maybe' made unscored entities win every tournament. If neither
-- candidate is strictly better, the second one drawn is selected.
--
-- Selecting from an empty list of entities is an error.
tournamentSelection :: (Ord s) => [ScoredEntity e s] -- ^ set of entities
                               -> Int -- ^ random seed
                               -> e -- ^ selected entity
tournamentSelection [] _ =
    error "(tournamentSelection) ERROR: no entities to select from"
tournamentSelection xs seed =
    case map pick (take 2 (randoms (mkStdGen seed))) of
      [(Just a, x1), (Just b, x2)] -> if a < b then x1 else x2
      -- a scored entity beats an unscored one
      [(Just _, x1), (Nothing, _)] -> x1
      -- first candidate is unscored: it can never be strictly better
      [_, (_, x2)]                 -> x2
      -- unreachable: exactly two candidates are drawn above
      _ -> error "(tournamentSelection) ERROR: expected two candidates"
  where
    len = length xs
    -- `mod` with a positive divisor always yields a valid index
    pick r = xs !! (r `mod` len)

-- |Apply crossover to obtain new entities.
--
-- Parents are picked through binary tournament selection. Twice as many
-- crossovers as requested are attempted, since a crossover may fail to
-- produce an entity; at most the requested number of results is kept.
performCrossover :: (Entity e s d p m) => Float -- ^ crossover parameter
                                     -> Int -- ^ number of entities
                                     -> Int -- ^ random seed
                                     -> p -- ^ pool for combining entities
                                     -> [ScoredEntity e s] -- ^ entities
                                     -> m [e] -- ^ combined entities
performCrossover p n seed pool es = do 
    offspring <- zipWithM (\s -> uncurry (crossover pool p s))
                          crossSeeds
                          parents
    return (take n (catMaybes offspring))
  where
    -- two selection seeds per crossover (one for each parent)
    (selSeeds,remaining) = takeAndDrop (2*2*n) (randoms (mkStdGen seed))
    -- one seed per attempted crossover
    (crossSeeds,_) = takeAndDrop (2*n) remaining
    parents = currify (map (tournamentSelection es) selSeeds)

-- |Apply mutation to obtain new entities.
--
-- Entities to mutate are picked through binary tournament selection.
-- Twice as many mutations as requested are attempted, since a mutation
-- may fail to produce an entity; at most the requested number of results
-- is kept.
performMutation :: (Entity e s d p m) => Float -- ^ mutation parameter
                                    -> Int -- ^ number of entities
                                    -> Int -- ^ random seed
                                    -> p -- ^ pool for mutating entities
                                    -> [ScoredEntity e s] -- ^ entities
                                    -> m [e] -- ^ mutated entities
performMutation p n seed pool es = do 
    mutants <- zipWithM (mutation pool p) mutSeeds candidates
    return (take n (catMaybes mutants))
  where
    -- one selection seed per attempted mutation
    (selSeeds,remaining) = takeAndDrop (2*n) (randoms (mkStdGen seed))
    -- one seed per attempted mutation
    (mutSeeds,_) = takeAndDrop (2*n) remaining
    candidates = map (tournamentSelection es) selSeeds

-- |Score a list of entities.
--
-- Scoring the whole list in one go ('scorePop') is tried first; when that
-- yields Nothing, the entities are scored one by one using 'score'.
scoreAll :: (Entity e s d p m) => d -- ^ dataset for scoring entities
                               -> [e] -- ^ universe of known entities
                               -> [e] -- ^ set of entities to score
                               -> m [Maybe s]
scoreAll dataset univEnts ents =
    scorePop dataset univEnts ents
      >>= maybe (mapM (score dataset) ents) return
 
-- |Function to perform a single evolution step:
--
-- * score all entities in the population
--
-- * (optionally) rescore the entities in the archive
--
-- * combine with best entities so far (archive)
--
-- * create new population using crossover/mutation
--
-- * sort by fitness (lower is better, unscored entities last)
--
-- * retain best scoring entities in the archive
--
-- The universe of known entities is extended with the population that
-- was just scored.
evolutionStep :: (Entity e s d p m) => p -- ^ pool for crossover/mutation
                                  -> d -- ^ dataset for scoring entities
                                  -> (Int,Int,Int) -- ^ # of c/m/a entities
                                  -> (Float,Float) -- ^ c/m parameters
                                  -> Bool -- ^ rescore archive in each step?
                                  -> Universe e -- ^ known entities
                                  -> Generation e s -- ^ current generation
                                  -> Int -- ^ seed for next generation
                                  -> m (Universe e, Generation e s) 
                                     -- ^ renewed universe, next generation
evolutionStep pool
              dataset
              (cn,mn,an)
              (crossPar,mutPar)
              rescoreArchive
              universe
              (pop,archive)
              seed = do 
    -- score population
    -- try to score in a single go first
    scores <- scoreAll dataset universe pop
    -- rescore the archive only when asked to, otherwise keep the scores
    -- obtained earlier (these branches used to be swapped)
    archive' <- if rescoreArchive
      then do
        let archived = map snd archive
        scores' <- scoreAll dataset universe archived
        return $ zip scores' archived
      else return archive
    let scoredPop = zip scores pop
        -- combine with archive for selection
        combo = scoredPop ++ archive'
        -- split seeds for crossover/mutation selection/seeds
        g = mkStdGen seed
        [crossSeed,mutSeed] = take 2 $ randoms g
    -- apply crossover and mutation
    crossEnts <- performCrossover crossPar cn crossSeed pool combo
    mutEnts <- performMutation mutPar mn mutSeed pool combo
    let -- new population: crossovered + mutated entities
        newPop = crossEnts ++ mutEnts
        -- new archive: best entities so far
        newArchive = take an $ nub $ sortBy bestFirst combo
        newUniverse = nub $ universe ++ pop
    return (newUniverse, (newPop,newArchive))
  where
    -- Order scored entities by ascending score, and put entities without
    -- a score last. Plain comparison of the 'Maybe' scores would rank
    -- 'Nothing' before any actual score, i.e. as the best entity.
    bestFirst (Just a, _)  (Just b, _)  = compare a b
    bestFirst (Just _, _)  (Nothing, _) = LT
    bestFirst (Nothing, _) (Just _, _)  = GT
    bestFirst (Nothing, _) (Nothing, _) = EQ

-- |Evolution: evaluate generation and continue.
--
-- Evolves one generation per (index, seed) pair. Evolution stops early as
-- soon as the first entity in the archive has a score and is considered
-- perfect ('isPerfect'). An empty archive, or a first archive entry without
-- a score, never counts as perfect (this used to be a pattern match
-- failure once 'isPerfect' inspected its argument).
evolution :: (Entity e s d p m) => GAConfig -- ^ configuration for GA
                                -> Universe e -- ^ known entities 
                                -> Generation e s -- ^ current generation
                                -> (   Universe e
                                    -> Generation e s 
                                    -> Int 
                                    -> m (Universe e, Generation e s)
                                   ) -- ^ function that evolves a generation
                                -> [(Int,Int)] -- ^ gen indices and seeds
                                -> m (Generation e s) -- ^ evolved generation
-- no more gen. indices/seeds => quit
evolution _ _ gen _ [] = return gen
evolution cfg universe gen step ((_,seed):gss) = do
    (universe',nextGen) <- step universe gen seed 
    case snd nextGen of
      ((Just fitness, e):_)
        | isPerfect (e,fitness) -> return nextGen
      -- no (scored) best entity, or not perfect yet: keep evolving
      _ -> evolution cfg universe' nextGen step gss

-- |Generate file name for checkpoint.
--
-- The name is located in the @checkpoints@ directory, and is composed of
-- the population size, archive size, crossover/mutation rates and
-- crossover/mutation parameters of the configuration, followed by the
-- generation index and the random seed.
chkptFileName :: GAConfig -- ^ configuration for genetic algorithm
              -> (Int,Int) -- ^ generation index and random seed
              -> FilePath -- ^ path of checkpoint file
chkptFileName cfg (gi,seed) =
    concat [ "checkpoints/GA-", cfgTxt
           , "-gen", show gi
           , "-seed-", show seed
           , ".chk"
           ]
  where
    -- configuration settings, separated by dashes
    cfgTxt = foldr1 (\piece rest -> piece ++ "-" ++ rest)
               [ show (getPopSize cfg)
               , show (getArchiveSize cfg)
               , show (getCrossoverRate cfg)
               , show (getMutationRate cfg)
               , show (getCrossoverParam cfg)
               , show (getMutationParam cfg)
               ]

-- |Checkpoint a single generation.
--
-- Reports the checkpoint on standard output, makes sure the @checkpoints@
-- directory exists, and writes the shown generation (population and
-- archive) to the checkpoint file for this generation index and seed.
checkpointGen :: (Entity e s d p m) => GAConfig -- ^ configuration for GA
                                  -> Int -- ^ generation index
                                  -> Int -- ^ random seed for generation
                                  -> Generation e s -- ^ current generation
                                  -> IO() -- ^ writes to file
checkpointGen cfg index seed gen@(_,_) = do
    putStrLn ("writing checkpoint for gen " ++ show index ++ " to " ++ path)
    createDirectoryIfMissing True "checkpoints"
    writeFile path (show gen)
  where
    path = chkptFileName cfg (index,seed)

-- |Evolution: evaluate generation, (maybe) checkpoint, continue.
--
-- Evolves one generation per (index, seed) pair, and reports the best
-- entity of each generation on standard output. Evolution stops early as
-- soon as the first entity in the archive has a score and is considered
-- perfect ('isPerfect').
--
-- When the archive is empty, or its first entry has no score, this is
-- reported and evolution simply continues (this used to be a pattern
-- match failure).
evolutionChkpt :: (Entity e s d p m, 
                   MonadIO m) => GAConfig -- ^ configuration for GA
                              -> Universe e -- ^ universe of known entities
                              -> Generation e s -- ^ current generation
                              -> (   Universe e 
                                  -> Generation e s 
                                  -> Int 
                                  -> m (Universe e, Generation e s)
                                 ) -- ^ function that evolves a generation
                              -> [(Int,Int)] -- ^ gen indices and seeds
                              -> m (Generation e s) -- ^ evolved generation
evolutionChkpt cfg universe gen step ((gi,seed):gss) = do
    (universe',newPa@(_,archive')) <- step universe gen seed
    -- checkpoint generation if desired
    liftIO $ if (getWithCheckpointing cfg)
      then checkpointGen cfg gi seed newPa
      else return () -- skip checkpoint
    case archive' of
      ((Just fitness, e):_) -> do
        liftIO $ putStrLn $ "best entity (gen. " 
                         ++ show gi ++ "): " ++ (show e) 
                         ++ " [fitness: " ++ show fitness ++ "]"
        -- check for perfect entity
        if isPerfect (e, fitness)
          then do 
            liftIO $ putStrLn $ "perfect entity found, "
                             ++ "finished after " ++ show gi 
                             ++ " generations!"
            return newPa
          else evolutionChkpt cfg universe' newPa step gss
      -- nothing to report or to check for perfection: keep evolving
      _ -> do
        liftIO $ putStrLn $ "no scored entity available (gen. "
                         ++ show gi ++ ")"
        evolutionChkpt cfg universe' newPa step gss

-- no more gen. indices/seeds => quit
evolutionChkpt _ _ gen _ [] = do 
    liftIO $ putStrLn $ "done evolving!"
    return gen

-- |Initialize.
--
-- Generates the initial population, and derives from the configuration:
-- the number of entities to generate by crossover and by mutation (the
-- respective rate times the population size, rounded), the archive size,
-- the crossover and mutation parameters, and one seed per generation
-- (paired with the generation index, starting at 0).
initGA :: (Entity e s d p m) => StdGen  -- ^ random generator
                           -> GAConfig -- ^ configuration for GA
                           -> p -- ^ pool for generating random entities
                           -> m ([e],Int,Int,Int,
                                 Float,Float,[(Int,Int)]
                                ) -- ^ initialization result
initGA g cfg pool = do
    pop <- initPop pool popSize firstSeed
    return ( pop
           , entityCount getCrossoverRate
           , entityCount getMutationRate
           , getArchiveSize cfg
           , getCrossoverParam cfg
           , getMutationParam cfg
           , zip [0..] (take (getMaxGenerations cfg) laterSeeds)
           )
  where
    -- first seed is for the initial population, the rest for evolution
    (firstSeed:laterSeeds) = randoms g :: [Int]
    popSize = getPopSize cfg
    -- number of entities that corresponds with a fraction of the population
    entityCount rate = round (rate cfg * fromIntegral popSize)

-- |Do the evolution!
--
-- Does not support checkpointing: an error is raised when checkpointing
-- is enabled in the configuration (use 'evolveVerbose' instead).
-- Starts from a fresh random population, an empty archive and an empty
-- universe, and yields the archive of the last generation evolved.
evolve :: (Entity e s d p m) => StdGen -- ^ random generator
                             -> GAConfig -- ^ configuration for GA
                             -> p -- ^ random entities pool
                             -> d -- ^ dataset required to score entities
                             -> m [ScoredEntity e s] -- ^ best entities
evolve g cfg pool dataset = do
    -- initialize
    (pop, cCnt, mCnt, aSize, crossPar, mutPar, genSeeds) <-
        if getWithCheckpointing cfg
          then error $ "(evolve) No checkpointing support " 
                    ++ "(requires liftIO); see evolveVerbose."
          else initGA g cfg pool
    -- do the evolution
    (_, finalArchive) <-
        evolution cfg [] (pop, [])
                  (evolutionStep pool dataset
                                 (cCnt, mCnt, aSize)
                                 (crossPar, mutPar)
                                 (getRescoreArchive cfg))
                  genSeeds
    -- return best entities
    return finalArchive

-- |Try to restore from checkpoint.
--
-- The (generation index, seed) pairs are tried in the order given; the
-- first one for which a checkpoint file exists and can be parsed is
-- restored. A checkpoint file that cannot be parsed (e.g. because it is
-- truncated) is reported and skipped, rather than triggering an error
-- later on when the restored generation is first used.
restoreFromChkpt :: (Entity e s d p m) => GAConfig -- ^ configuration for GA
                                       -> [(Int,Int)] -- ^ gen indices/seeds
                                       -> IO (Maybe (Int,Generation e s)) 
                                          -- ^ restored generation (if any)
restoreFromChkpt _ [] = return Nothing
restoreFromChkpt cfg ((gi,seed):genSeeds) = do
    chkptFound <- doesFileExist fn
    if chkptFound 
      then do
        txt <- readFile fn
        -- total variant of 'read': accept a single complete parse only
        -- (this also consumes the file contents right away)
        case [gen | (gen,rest) <- reads txt, ("","") <- lex rest] of
          [gen] -> return $ Just (gi, gen)
          _     -> do
            putStrLn $ "ignoring unreadable checkpoint " ++ fn
            restoreFromChkpt cfg genSeeds
      else restoreFromChkpt cfg genSeeds
  where
    fn = chkptFileName cfg (gi,seed)

-- |Do the evolution (supports checkpointing). 
--
-- Requires support for liftIO in monad used.
--
-- When checkpointing is enabled, a restore is attempted first, trying the
-- checkpoints of the latest generations first; only the generations that
-- come after the restored one are evolved. Without a restored checkpoint,
-- evolution starts from a fresh random population and an empty archive.
-- Yields the archive of the last generation evolved.
evolveVerbose :: (Entity e s d p m, 
                  MonadIO m) => StdGen -- ^ random generator
                             -> GAConfig -- ^ configuration for GA
                             -> p -- ^ random entities pool
                             -> d -- ^ dataset required to score entities
                             -> m [ScoredEntity e s] -- ^ best entities
evolveVerbose g cfg pool dataset = do
    -- initialize
    (pop, cCnt, mCnt, aSize, crossPar, mutPar, genSeeds) <-
        initGA g cfg pool
    -- (maybe) restore from checkpoint
    restored <- liftIO $
        if getWithCheckpointing cfg
          then restoreFromChkpt cfg (reverse genSeeds)
          else return Nothing
    let (gi, gen) = case restored of
            -- restored pop/archive from checkpoint
            Just resumed -> resumed
            -- restore failed, new population and empty archive
            Nothing      -> (-1, (pop, []))
        -- filter out seeds from past generations
        remaining = filter (\(i, _) -> i > gi) genSeeds
    -- do the evolution
    (_, finalArchive) <-
        evolutionChkpt cfg [] gen
                       (evolutionStep pool dataset
                                      (cCnt, mCnt, aSize)
                                      (crossPar, mutPar)
                                      (getRescoreArchive cfg))
                       remaining
    -- return best entities
    return finalArchive

-- |Random search.
--
-- Useful to compare with results from genetic algorithm.
--
-- Generates the requested number of random entities and scores them
-- (using an empty universe); every entity is returned, paired with its
-- score, in the order of generation.
randomSearch :: (Entity e s d p m) => StdGen -- ^ random generator
                                   -> Int -- ^ number of random entities
                                   -> p -- ^ random entity pool
                                   -> d -- ^ scoring dataset
                                   -> m [ScoredEntity e s] -- ^ scored ents
randomSearch g n pool dataset = do
    candidates <- initPop pool n (fst (random g))
    fitnesses <- scoreAll dataset [] candidates
    return (zip fitnesses candidates)