{-# LANGUAGE ScopedTypeVariables      #-}
{-# LANGUAGE TypeOperators            #-}
{-# LANGUAGE FlexibleContexts         #-}
{-# LANGUAGE OverlappingInstances     #-}
{-# LANGUAGE FlexibleInstances        #-}
{-# LANGUAGE GADTs                    #-}
{-# LANGUAGE PolyKinds                #-}

module HarmTrace.Models.Generator (
    Generate(..), GenerateG(..), genGdefault, arbitrary
  , Gen, FrequencyTable, frequencies, frequency
  ) where

import Generics.Instant.Base
import Generics.Instant.Instances ()

import Test.QuickCheck (Gen, frequency, sized)
import Data.Maybe (fromJust)
-- import Debug.Trace (trace)

-- Utility functions for data generation

-- | A frequency table detailing how often certain constructors should be
-- picked. The 'String' is the constructor name and the 'Int' its relative
-- weight. A constructor that does not appear in the table is given
-- weight 1 by 'frequencies'.
type FrequencyTable = [(String,Int)]

-- | Total weight of a group of constructors: each name contributes its
-- entry in the table, or 1 when the name is absent. An empty group
-- weighs 0.
frequencies :: [String] -> FrequencyTable -> Int
frequencies names table = sum (map weightOf names)
  where
    weightOf name = maybe 1 id (lookup name table)


-- | Generic generator worker, defined by induction over the
-- instant-generics representation types.
class Generate a where
  gen' :: FrequencyTable  -- ^ constructor name -> relative weight
       -> Int             -- ^ divisor applied to the size before descending
                          --   into a 'Var' or 'Rec' field
       -> Int             -- ^ size handed on to 'genG' for fields
       -> Maybe (Gen a)   -- ^ 'Nothing' when no value of this shape can be
                          --   generated

-- | A nullary constructor is always available and yields 'U'.
instance Generate U where
  gen' _ _ _ = Just (return U)
instance ( Generate a, ConNames a
         , Generate b, ConNames b) => Generate (a :+: b) where
  -- | Choose between the two alternatives of a sum. Each side is weighted by
  -- the summed 'frequencies' of the constructors it contains.
  --
  -- An alternative is dropped when it cannot be generated ('Nothing') or
  -- when its total weight is not positive. If no alternative survives, the
  -- whole sum yields 'Nothing'. It does not hand QuickCheck's 'frequency' a
  -- table it refuses to sample (empty, or with only zero weights).
  gen' ft m n =
    let aFrequency = frequencies (conNames (undefined :: a)) ft
        bFrequency = frequencies (conNames (undefined :: b)) ft
        -- The weight guard is evaluated before the branch's generator, so a
        -- disabled alternative's generator is never even constructed.
        rl = [ (aFrequency, fmap L g)
             | aFrequency > 0, Just g <- [gen' ft m n] ]
        rr = [ (bFrequency, fmap R g)
             | bFrequency > 0, Just g <- [gen' ft m n] ]
    in case rl ++ rr of
         []   -> Nothing
         alts -> Just (frequency alts)
-- | A product is generable only if both factors are. The 'Maybe' binds fail
-- the whole product as soon as either side is 'Nothing'. The resulting
-- generator draws the left factor first, then the right.
instance (Generate a, Generate b) => Generate (a :*: b) where
  gen' ft m n =
    gen' ft m n >>= \genA ->
    gen' ft m n >>= \genB ->
    Just (genA >>= \x -> genB >>= \y -> return (x :*: y))

-- | Constructor node whose two index types coincide: generate the fields and
-- wrap them with 'C'. (When the indices differ, the more general instance
-- for @CEq c p q a@ applies instead and reports 'Nothing'.)
instance (Generate a) => Generate (CEq c p p a) where
  gen' ft m n = do fields <- gen' ft m n
                   return (fmap C fields)

-- | Fallback for constructor nodes whose index types differ. It is chosen
-- only when the more specific @CEq c p p a@ instance does not match
-- (OverlappingInstances). Such a constructor cannot be generated, so it
-- reports 'Nothing' and the enclosing sum drops it.
instance Generate (CEq c p q a) where
  gen' = \_ _ _ -> Nothing

-- | A 'Var' field is generated by its own 'GenerateG' instance, with the size
-- divided by @m@. NOTE: @m@ is used with 'div', so it must be non-zero.
instance (GenerateG a) => Generate (Var a) where
  gen' ft m n = case genG ft (n `div` m) of
    Nothing -> Nothing
    Just g  -> Just (fmap Var g)

-- | A 'Rec' field is generated by its own 'GenerateG' instance, with the size
-- divided by @m@. NOTE: @m@ is used with 'div', so it must be non-zero.
instance (GenerateG a) => Generate (Rec a) where
  gen' ft m n = maybe Nothing (Just . fmap Rec) (genG ft (n `div` m))

-- | Per-type dispatcher: 'Var' and 'Rec' fields call 'genG' for the field's
-- own type. Instances may be hand-written or simply set to 'genGdefault'.
class GenerateG a where
  genG :: FrequencyTable  -- ^ constructor name -> relative weight
       -> Int             -- ^ size
       -> Maybe (Gen a)   -- ^ 'Nothing' when no value can be generated

-- | Generic generator for any 'Representable' type. It takes a table of
-- per-constructor weights (constructors missing from the table weigh 1)
-- and a size. It returns 'Nothing' when no constructor of the type can be
-- generated. The top-level size divisor is 1.
--
-- NOTE(review): within this module the size is only passed along (divided
-- by 1) and never tested, so it does not bound recursion by itself. A sum
-- whose leftmost alternative recurses into the same type may make building
-- the generator loop. Confirm against the 'GenerateG' instances defined
-- elsewhere.
genGdefault :: (Representable a, Generate (Rep a))
            => FrequencyTable -> Int -> Maybe (Gen a)
genGdefault ft size = case gen' ft 1 size of
  Nothing -> Nothing
  Just g  -> Just (fmap to g)

-- | Generic arbitrary function with default sizes and constructor frequencies.
-- QuickCheck's size parameter is the size budget. The 'FrequencyTable' is
-- empty, so every constructor gets weight 1.
--
-- When 'genGdefault' reports that no constructor of the type can be
-- generated, sampling fails with a descriptive 'error'. Previously it
-- failed with an anonymous @fromJust: Nothing@ crash.
arbitrary :: (Representable a, Generate (Rep a)) => Gen a
arbitrary = sized (maybe noGenerator id . genGdefault [])
  where
    noGenerator = error ("HarmTrace.Models.Generator.arbitrary: "
                         ++ "no constructor of this type can be generated")

-- Adhoc instances
-- none

-- Generic instances
-- These types are generated purely generically through 'genGdefault'.
instance (GenerateG a) => GenerateG (Maybe a) where
  genG ft size = genGdefault ft size
instance (GenerateG a) => GenerateG [a] where
  genG ft size = genGdefault ft size
instance (GenerateG a, GenerateG b) => GenerateG (a,b) where
  genG ft size = genGdefault ft size


-- | Names of the constructors contained in a representation type. They are
-- used to look up weights in a 'FrequencyTable'. The argument is only a
-- type witness. By default a node contributes no names.
class ConNames a where
  conNames :: a -> [String]
  conNames = const []

-- | A sum contributes the names of both alternatives, left ones first.
instance (ConNames a, ConNames b) => ConNames (a :+: b) where
  conNames _ = concat [ conNames (undefined :: a)
                      , conNames (undefined :: b) ]
-- | A constructor node contributes exactly its own name.
instance (ConNames a, Constructor c) => ConNames (CEq c p q a) where
  conNames = (: []) . conName

-- Units, products and fields carry no constructor names of their own.
instance ConNames U         where conNames _ = []
instance ConNames (f :*: g) where conNames _ = []
instance ConNames (Var a)   where conNames _ = []
instance ConNames (Rec a)   where conNames _ = []

-- | Binary tree with values at the leaves. 'hFixpoints' uses it to mirror
-- the nesting of a representation's sums, with one leaf per constructor.
data Tree a
  = Leaf a
  | Node (Tree a) (Tree a)
  deriving Show

-- | Catamorphism for 'Tree': replace each 'Leaf' with the first function and
-- each 'Node' with the second.
foldTree :: (a -> b) -> (b -> b -> b) -> Tree a -> b
foldTree onLeaf onNode = go
  where
    go (Leaf v)   = onLeaf v
    go (Node l r) = onNode (go l) (go r)

-- | Sum of all leaf values of a tree.
sumTree :: Tree Int -> Int
sumTree t = foldTree (\v -> v) add t
  where add x y = x + y

-- | Counts the 'Rec' fields of each constructor of a representation type.
-- The counts are arranged in a 'Tree' that mirrors the nesting of the sum.
-- The argument is a type witness; the instances in this module never
-- inspect it.
class Fixpoints a where
  hFixpoints :: a
             -> Tree Int

-- | A sum branches into one subtree per alternative.
instance (Fixpoints a, Fixpoints b) => Fixpoints (a :+: b) where
  hFixpoints _ = Node leftTree rightTree
    where leftTree  = hFixpoints (undefined :: a)
          rightTree = hFixpoints (undefined :: b)
-- | A constructor node is transparent: count the fields it wraps.
instance (Fixpoints a) => Fixpoints (CEq c p q a) where
  hFixpoints = const (hFixpoints (undefined :: a))

-- | A product is one constructor's field list. It collapses to a single
-- leaf holding the total count of both factors.
instance (Fixpoints a, Fixpoints b) => Fixpoints (a :*: b) where
  -- Summing each factor with 'sumTree' is total. It gives the same result
  -- when a factor is a 'Leaf', and still gives a sensible total if a factor
  -- ever returns a 'Node'. An irrefutable @Leaf m@ binding would instead fail
  -- at runtime in that case.
  hFixpoints _ = Leaf (sumTree (hFixpoints (undefined :: a))
                       + sumTree (hFixpoints (undefined :: b)))

-- Field-level counts: each 'Rec' field counts 1; 'Var' fields and 'U' count 0.
instance Fixpoints (Rec a) where
  hFixpoints = const (Leaf 1)
instance Fixpoints (Var a) where
  hFixpoints = const (Leaf 0)
instance Fixpoints U where
  hFixpoints = const (Leaf 0)

-- | Per-constructor counts of 'Rec' fields for the type of the argument,
-- shaped like its sum of constructors. Only the argument's type is used.
fixpoints :: (Representable a, Fixpoints (Rep a)) => a -> Tree Int
fixpoints x = hFixpoints (repWitness x)
  where
    -- A placeholder of the representation type. Its value is undefined, so
    -- 'hFixpoints' must not inspect it.
    repWitness :: b -> Rep b
    repWitness _ = undefined