{-#LANGUAGE DeriveDataTypeable, TemplateHaskell #-}

-- | Everything you need to construct an enumeration for an algebraic type.
-- Just define each constructor using 'nullary' (or 'pure') for nullary
-- constructors and 'unary' together with 'funcurry' for constructors of
-- positive arity, then combine the constructors with 'consts'. Example:
-- 
-- @
--  instance Enumerable a => Enumerable [a] where
--    enumerate = consts [unary (funcurry (:)), pure []]
-- @
--
-- There is also a Template Haskell function, 'deriveEnumerable', for
-- automatic derivation.


module Test.Feat.Class (
  Enumerable(..),
  
  -- ** Building instances
  Constructor,
  nullary,
  unary,
  funcurry,
  consts,
  
  -- ** Accessing the enumerator of an instance
  optimised,
  
  -- *** Free pairs
  FreePair(..),
  
  
  -- ** Deriving instances with template haskell
  deriveEnumerable,
  -- autoCon,
  -- autoCons
  
  

  ) where

-- testing-feat
import Test.Feat.Enumerate
import Test.Feat.Internals.Tag(Tag(Class))
import Test.Feat.Internals.Derive
-- base
import Data.Typeable
-- template-haskell
import Language.Haskell.TH
import Language.Haskell.TH.Syntax
-- base - only for instances
import Data.Word
import Data.Int
import Data.Bits

-- | A class of functionally enumerable types.
class Typeable a => Enumerable a where
  -- | This is the interface for defining an instance. Memoisation needs to
  -- be ensured, e.g. using 'mempay' (which 'consts' applies for you), but
  -- sharing is handled automatically by the default implementation of
  -- 'shared'.
  enumerate  :: Enumerate a

  -- | Version of 'enumerate' that ensures it is shared between
  -- all accessing functions. Should always be used when
  -- combining enumerations.
  -- Should typically be left to its default behaviour.
  shared     :: Enumerate a
  shared  = tagShare Class enumerate

-- | An optimised version of 'enumerate', taken from the shared enumeration
-- ('shared'). It is used by all library functions that access enumerated
-- values, but not by combining functions. Library functions should make
-- sure @optimised@ is not reevaluated.
optimised
  :: Enumerable a
  => Enumerate a
optimised = optimise shared

-- | A free pair constructor. Building a free pair costs exactly the sum of
-- the costs of its two components.
newtype FreePair a b = Free { free :: (a, b) }
  deriving (Show, Typeable)

-- | Turn a curried two-argument function (typically a constructor) into a
-- function on free pairs. The pair is matched lazily, so this is exactly
-- as lazy as @uncurry f . free@.
funcurry :: (a -> b -> c) -> FreePair a b -> c
funcurry f (Free ~(x, y)) = f x y

instance (Enumerable a, Enumerable b) => Enumerable (FreePair a b) where
  -- Each component is drawn from its own shared enumeration. The result is
  -- wrapped with 'mem' rather than 'mempay' (which 'consts' uses), matching
  -- the documented cost of a free pair: the sum of its components.
  enumerate = mem (mkPair <$> shared <*> shared)
    where
      mkPair x y = Free (x, y)

-- | The enumeration of the values built by a single constructor; a synonym
-- for 'Enumerate'. Values of this type are produced by 'nullary' and
-- 'unary' and combined into a full enumeration with 'consts'.
type Constructor = Enumerate
  
-- | Enumeration of a nullary constructor such as @True@ or @[]@;
-- a synonym for 'pure'.
nullary :: a -> Constructor a
nullary x = pure x

-- | Enumeration of any non-nullary constructor. First apply 'funcurry'
-- until the constructor takes a single argument (n-1 times for a
-- constructor with n fields). That argument is then drawn from the
-- 'shared' enumeration of its type.
unary :: Enumerable a => (a -> b) -> Constructor b
unary f = fmap f shared

-- | Combine the enumerations of a type's constructors into the enumeration
-- of the type: the constructors are concatenated with 'mconcat' and the
-- result is wrapped with 'mempay'. Use this even for a type with only one
-- constructor; the result of 'unary' should typically not be used directly
-- in an instance.
consts :: [Constructor a] -> Enumerate a
consts = mempay . mconcat


--------------------------------------------------------------------
-- Automatic derivation

-- | Derive an instance of 'Enumerable' with Template Haskell.
--
-- The generated 'enumerate' is @consts [...]@ with one entry per
-- constructor. A nullary constructor @C@ becomes @pure C@. A constructor
-- with @n > 0@ fields is wrapped in @n-1@ applications of 'funcurry' and
-- then passed to 'unary'. The instance declaration itself is built by
-- @instanceFor@.
deriveEnumerable :: Name -> Q [Dec]
deriveEnumerable = fmap return . instanceFor ''Enumerable [enumDef]

-- -- | Derive the enumeration of a single constructor. Useful 
-- if 'deriveEnumerable' does not work for all constructors. 
-- autoCon :: Name -> Q Exp
-- autoCon = undefined

-- -- | Splices a list of automatically derived constructors.
-- autoCons :: [Name] -> Q Exp
-- autoCons = listE . map autoCon

-- | Generate the single declaration @enumerate = consts [...]@ from a list
-- of constructors, each given as its name and field types. Only the number
-- of fields matters:
--
-- * no fields: @pure Con@
--
-- * @k > 0@ fields: @unary (funcurry (... (funcurry Con)))@, with @k-1@
--   applications of 'funcurry'.
enumDef :: [(Name,[Type])] -> [Q Dec]
enumDef cons = [valD (varP 'enumerate) (normalB [|consts $conList|]) []]
  where
    conList = listE (map conEnum cons)
    conEnum (conName, fields) = case fields of
      []       -> [|pure $(conE conName)|]
      (_:rest) -> [|unary $(curried conName rest)|]
    -- One 'funcurry' per remaining field, applied around the constructor.
    curried conName rest =
      foldr (\_ acc -> appE [|funcurry|] acc) (conE conName) rest





---------------------------------------------------------------------
-- Instances


-- Top-level Template Haskell splice that derives 'Enumerable' instances for
-- a fixed list of standard types. For each type name, @instanceFor@ runs
-- with the local 'enumDef' below (the same generator 'deriveEnumerable'
-- uses), and 'mapM' collects the resulting declarations.
(let 
  it = mapM (instanceFor ''Enumerable [enumDef]) 
    [ ''[] 
    , ''Bool
    , ''()
    , ''(,)
    , ''(,,)
    , ''(,,,)
    , ''(,,,,)
    , ''(,,,,,)
    , ''(,,,,,,) -- This is as far as typeable goes...
    , ''Either
    , ''Maybe
    , ''Ordering
    ]
  -- Circumventing the stage restrictions by means of code repetition.
  -- A splice cannot run a function defined in the same module, so this is
  -- a copy of the top-level 'enumDef'. Keep the two definitions in sync.
  enumDef :: [(Name,[Type])] -> [Q Dec]
  enumDef cons = [fmap mk_freqs_binding [|consts $ex |]] where
    -- One enumeration per constructor, combined with 'consts'.
    ex = listE $ map cone cons
    -- Nullary constructor: @pure Con@.
    cone (n,[]) = [|pure $(conE n)|]
    -- Constructor with k > 0 fields: apply 'funcurry' k-1 times, then 'unary'.
    cone (n,_:vs) = 
      [|unary $(foldr appE (conE n) (map (const [|funcurry|] ) vs) )|]
    -- Wrap the expression as the binding @enumerate = <expr>@.
    mk_freqs_binding :: Exp -> Dec
    mk_freqs_binding e = ValD (VarP 'enumerate) (NormalB e) []
  in it)
  


-- Non-negative integers, enumerated so that part sizes grow exponentially.
-- Part p holds nothing for p <= 0, just 0 for p == 1, and the 2^(p-2)
-- values 2^(p-2) .. 2^(p-1)-1 for p >= 2. This exponential growth is what
-- lets 'word' and 'int' cover a fixed-width type within bitSize+1 parts,
-- so the other instances depend on it.
newtype Natural = Natural {natural :: Integer} deriving (Typeable, Show)

instance Enumerable Natural where
  enumerate = naturals
    where
      -- The optimised version of this enumeration is the enumeration itself.
      naturals = Enumerate
        { card    = sizeOfPart
        , select  = pick
        , optimal = return naturals
        }
      sizeOfPart p
        | p <= 0    = 0
        | p == 1    = 1
        | otherwise = 2 ^ (p - 2)
      pick 1 0 = Natural 0
      pick p i = Natural (2 ^ (p - 2) + i)

-- Integers are a sign flag freely paired with a 'Natural' magnitude:
-- False maps n to n, and True maps n to -n-1. The Int* instances (through
-- 'int') reuse this enumeration, so it must stay exponential as well.
instance Enumerable Integer where
  enumerate = unary fromSignMagnitude
    where
      fromSignMagnitude (Free (isNegative, Natural n))
        | isNegative = negate n - 1
        | otherwise  = n
           

-- Enumeration of an unsigned fixed-width type, built from 'Natural' and
-- cut off after bitSize+1 parts. Given the 'Natural' part sizes, those
-- parts hold exactly the values 0 .. 2^bitSize - 1.
-- NOTE: an exported version would have to use $tag instead of Class.
word :: (Bits a, Integral a) => Enumerate a
word = bounded
  where
    -- 'bounded' refers to itself only so that 'bitSize'' can see its type.
    bounded = cutOff (bitSize' bounded + 1) (unary (fromInteger . natural))
  
-- Enumeration of a signed fixed-width type, built from the 'Integer'
-- enumeration and cut off after bitSize+1 parts.
int :: (Bits a, Integral a) => Enumerate a
int = bounded
  where
    -- 'bounded' refers to itself only so that 'bitSize'' can see its type.
    bounded = cutOff (bitSize' bounded + 1) (unary fromInteger)

-- | Truncate an enumeration: every part above the limit reports
-- cardinality 0. 'select' is left untouched, and the optimised version is
-- truncated the same way.
cutOff :: Int -> Enumerate a -> Enumerate a
cutOff limit en = en { card = limitedCard, optimal = fmap (cutOff limit) (optimal en) }
  where
    limitedCard p
      | p > limit = 0
      | otherwise = card en p

-- | Bit width of the element type @a@ of any @f a@. The argument itself is
-- never inspected. 'bitSize' is applied to 'undefined', so this relies on
-- 'bitSize' not evaluating its argument.
bitSize' :: Bits a => f a -> Int
bitSize' = bitSize . elemOf
  where
    elemOf :: f a -> a
    elemOf _ = undefined

-- Unsigned fixed-width types: each one reuses 'word'.
instance Enumerable Word   where enumerate = word
instance Enumerable Word8  where enumerate = word
instance Enumerable Word16 where enumerate = word
instance Enumerable Word32 where enumerate = word
instance Enumerable Word64 where enumerate = word

-- Signed fixed-width types: each one reuses 'int'.
instance Enumerable Int   where enumerate = int
instance Enumerable Int8  where enumerate = int
instance Enumerable Int16 where enumerate = int
instance Enumerable Int32 where enumerate = int
instance Enumerable Int64 where enumerate = int

-- | Not injective: values are built with 'encodeFloat' from a free pair of
-- an 'Integer' mantissa and an 'Int' exponent, and distinct pairs can
-- encode the same number (e.g. @encodeFloat 1 1 == encodeFloat 2 0@).
instance Enumerable Double where
  enumerate = unary fromMantissaExponent
    where fromMantissaExponent = funcurry encodeFloat

-- | Not injective, for the same reason as the 'Double' instance.
instance Enumerable Float where
  enumerate = unary fromMantissaExponent
    where fromMantissaExponent = funcurry encodeFloat

-- | Contains only ASCII characters. Character codes are 'Word' values
-- limited to the first 8 parts, which (given the 'Natural' part sizes)
-- are the codes 0 .. 127.
instance Enumerable Char where
  enumerate = cutOff 8 (unary fromCode)
    where
      fromCode :: Word -> Char
      fromCode = toEnum . fromIntegral