{-#LANGUAGE DeriveDataTypeable, TemplateHaskell #-}

-- | Everything you need to construct an enumeration for an algebraic type.
-- Just define each constructor using pure for nullary constructors and 
-- unary and funcurry for positive arity constructors, then combine the 
-- constructors with consts. Example:
-- 
-- @
--  instance Enumerable a => Enumerable [a] where
--    enumerate = consts [unary (funcurry (:)), pure []]
-- @
--
-- There's also a handy Template Haskell function for automatic derivation.


module Test.Feat.Class (
  Enumerable(..),
  
  -- ** Building instances
  Constructor,
  nullary,
  unary,
  funcurry,
  consts,
  
  -- ** Accessing the enumerator of an instance
  optimised,
  
  -- *** Free pairs
  FreePair(..),
  
  
  -- ** Deriving instances with Template Haskell
  deriveEnumerable,
  deriveEnumerable',
  ConstructorDeriv,
  dAll,
  dExcluding,
  dExcept
  -- autoCon,
  -- autoCons
  ) where

-- testing-feat
import Test.Feat.Enumerate
import Test.Feat.Internals.Tag(Tag(Class))
import Test.Feat.Internals.Derive
import Test.Feat.Internals.Newtypes
-- base
import Data.Typeable
import Data.Monoid
-- template-haskell
import Language.Haskell.TH
import Language.Haskell.TH.Syntax
-- base - only for instances
import Data.Word
import Data.Int
import Data.Bits
import Data.Ratio

-- | A class of functionally enumerable types.
--
-- Instances are normally written with 'consts', 'unary', 'funcurry' and
-- 'nullary', or generated with 'deriveEnumerable'.
class Typeable a => Enumerable a where
  -- | This is the interface for defining an instance. Memoisation needs to 
  -- be ensured, e.g. using 'mempay', but sharing is handled automatically by 
  -- the default implementation of 'shared'.
  enumerate  :: Enumerate a
  
  -- | Version of 'enumerate' that ensures it is shared between
  -- all accessing functions. Should always be used when 
  -- combining enumerations.
  --
  -- The default implementation is @'tagShare' 'Class' 'enumerate'@ and
  -- should typically be left as it is.
  shared     :: Enumerate a
  shared  = tagShare Class enumerate

-- | An optimised version of 'enumerate', obtained by applying 'optimise'
-- to 'shared'. Used by all library functions that access enumerated values
-- (but not by combining functions, which should use 'shared'). Library
-- functions should ensure that @optimised@ is not reevaluated.
optimised :: Enumerable a => Enumerate a
optimised = optimise shared   

-- | A free pair constructor. The cost of constructing a free pair
-- is equal to the sum of the costs of its components. 
--
-- 'Free' wraps an ordinary tuple and 'free' unwraps it again.
newtype FreePair a b = Free {free :: (a,b)} 
  deriving (Show, Typeable)

-- | Turn a curried two-argument function (typically a constructor) into a
-- function that takes its arguments packed in a single 'FreePair'.
funcurry :: (a -> b -> c) -> FreePair a b -> c
funcurry f (Free pair) = uncurry f pair

-- Free pairs combine the shared enumerations of both components
-- applicatively; the result is wrapped in 'mem'.
instance (Enumerable a, Enumerable b) => 
         Enumerable (FreePair a b) where
  enumerate = mem (mkPair <$> shared <*> shared)
    where
      mkPair x y = Free (x, y)

-- | The enumeration of the values built by a single constructor. This is a
-- plain synonym for 'Enumerate'; lists of these are combined by 'consts'.
type Constructor = Enumerate
  
-- | Enumeration for a constructor without fields, such as @True@ or @[]@.
-- This is simply 'pure' at the 'Constructor' type.
nullary :: a -> Constructor a
nullary value = pure value

-- | Enumeration for any constructor with at least one field. For a
-- constructor with @n@ fields, first apply 'funcurry' @n-1@ times so that
-- it becomes a one-argument function; that function is then mapped over the
-- 'shared' enumeration of its argument type.
unary :: Enumerable a => (a -> b) -> Constructor b
unary f = fmap f shared

-- | Builds the enumeration of a type from the enumerations of its
-- constructors: the list is merged with 'mconcat' and the result is passed
-- through 'mempay'. The result of 'unary' should typically not be used
-- directly in an instance, so apply @consts@ even when the type has only a
-- single constructor.
consts :: [Constructor a] -> Enumerate a
consts = mempay . mconcat


--------------------------------------------------------------------
-- Automatic derivation

-- | Derive an instance of 'Enumerable' with Template Haskell, using the
-- default rule for every constructor of the named type.
deriveEnumerable :: Name -> Q [Dec]
deriveEnumerable tyName = deriveEnumerable' (dAll tyName)

-- | Derivation settings consumed by 'deriveEnumerable'': the name of the
-- type to derive for, paired with per-constructor overrides (constructor
-- name and the expression to use for that constructor instead of the
-- default one).
type ConstructorDeriv = (Name, [(Name, ExpQ)])
-- | Settings for the named type with no constructor overrides.
dAll :: Name -> ConstructorDeriv
dAll tyName = (tyName, [])
-- | Exclude the named constructor: its entry is overridden with 'mempty'.
-- The new override is placed in front of the existing ones.
dExcluding :: Name -> ConstructorDeriv -> ConstructorDeriv
dExcluding ctor (tyName, overrides) = (tyName, excluded : overrides)
  where
    excluded = (ctor, [| mempty |])
-- | Use the given expression for the named constructor instead of the
-- default rule. The new override is placed in front of the existing ones.
dExcept :: Name -> ExpQ -> ConstructorDeriv -> ConstructorDeriv
dExcept ctor custom (tyName, overrides) = (tyName, override : overrides)
  where
    override = (ctor, custom)

-- | Derive an instance of 'Enumerable' with Template Haskell, with
-- rules for some specific constructors.
--
-- Constructors that have an override in the 'ConstructorDeriv' use the
-- supplied expression; all others use 'nullary' (no fields) or 'unary'
-- with one 'funcurry' per additional field. An override that names
-- something which is not a constructor of the type aborts the derivation
-- with an error.
deriveEnumerable' :: ConstructorDeriv -> Q [Dec]
deriveEnumerable' (tyName, overrides) = do
  inst <- instanceFor ''Enumerable [enumDef] tyName
  return [inst]
  where
    -- Builds the @enumerate = consts [...]@ binding from the constructors
    -- (name and field types) of the type.
    enumDef :: [(Name, [Type])] -> Q Dec
    enumDef ctors =
      checkOverrides >> fmap bindEnumerate [| consts $ctorList |]
      where
        ctorList = listE [ ctorExp c | c <- ctors ]

        -- An override, when present, wins over the default rule.
        ctorExp c@(cname, _) = case lookup cname overrides of
          Just custom -> custom
          Nothing     -> defaultExp c

        -- A constructor with k > 0 fields gets k-1 applications of funcurry.
        defaultExp (cname, []) = [| nullary $(conE cname) |]
        defaultExp (cname, _ : rest) =
          [| unary $(foldr appE (conE cname) [ [| funcurry |] | _ <- rest ]) |]

        bindEnumerate :: Exp -> Dec
        bindEnumerate body = ValD (VarP 'enumerate) (NormalB body) []

        -- Every overridden name must be one of the type's constructors.
        checkOverrides =
          case [ c | c <- map fst overrides, c `notElem` map fst ctors ] of
            []  -> return ()
            bad -> error $
              "Invalid constructors for " ++ show tyName ++ ": " ++ show bad
        
-- do
--          (_,ns,_) <- extractData n
--          if all (map snd nse) (`elem` ns) 
--            then return () 
--            else error $ "Invalid constructors for "++show n++": "++
--                        show (filter (`notElem` ns) (map fst nse)) 








---------------------------------------------------------------------
-- Instances


-- Top-level declaration splice generating the 'Enumerable' instances for
-- the basic data types listed below (lists, Bool, unit, tuples, Either,
-- Maybe and Ordering).
(let 
  -- One instance per listed type name, each built with 'enumDef'.
  it = mapM (instanceFor ''Enumerable [enumDef]) 
    [ ''[] 
    , ''Bool
    , ''()
    , ''(,)
    , ''(,,)
    , ''(,,,)
    , ''(,,,,)
    , ''(,,,,,)
    , ''(,,,,,,) -- This is as far as Typeable goes...
    , ''Either
    , ''Maybe
    , ''Ordering
    ]
  -- Circumventing the stage restrictions by means of code repetition:
  -- this mirrors the local enumDef of deriveEnumerable', without the
  -- override lookup and with 'pure' for constructors that have no fields.
  enumDef :: [(Name,[Type])] -> Q Dec
  enumDef cons = fmap mk_freqs_binding [|consts $ex |] where
    -- The list expression handed to 'consts', one entry per constructor.
    ex = listE $ map cone cons
    cone (n,[]) = [|pure $(conE n)|]
    -- One 'funcurry' is applied for every field after the first.
    cone (n,_:vs) = 
      [|unary $(foldr appE (conE n) (map (const [|funcurry|] ) vs) )|]
    -- Wraps the expression as the binding @enumerate = <expression>@.
    mk_freqs_binding :: Exp -> Dec
    mk_freqs_binding e = ValD (VarP 'enumerate) (NormalB e) []
  in it)
  


-- This instance is quite important: its parts need to grow exponentially
-- for the other instances to work.
instance Infinite a => Enumerable (Nat a) where
  enumerate = self
    where
      -- 'optimal' simply returns the enumeration itself, hence the knot.
      self = Enumerate
        { card    = partSize
        , select  = pick
        , optimal = return self
        }

      -- Parts below 1 are empty, part 1 holds a single value and every
      -- part p >= 2 holds 2^(p-2) values.
      partSize p = case compare p 1 of
        LT -> 0
        EQ -> 1
        GT -> 2 ^ (p - 2)

      -- Index 0 of part 1 is zero; otherwise the index is an offset from
      -- 2^(p-2).
      pick :: Num a => Part -> Index -> Nat a
      pick p i
        | p == 1 && i == 0 = Nat 0
        | otherwise        = Nat (2 ^ (p - 2) + fromInteger i)


-- This instance is used by the Int* instances and needs to be exponential
-- as well. An integer is built from a free pair of a flag and a natural
-- number @i@: the flag selects between @-i-1@ and @i@.
instance Enumerable Integer where
  enumerate = unary fromFlagged
    where
      fromFlagged (Free (negative, Nat i))
        | negative  = negate i - 1
        | otherwise = i

-- Non-zero values are obtained from the enumeration of the underlying type
-- by shifting every value that is >= 0 up by one, so zero is never produced
-- from it; negative values are kept as they are.
instance (Infinite a, Enumerable a) => Enumerable (NonZero a) where
  enumerate = unary skipZero
    where
      skipZero x = NonZero shifted
        where
          shifted
            | x >= 0    = x + 1
            | otherwise = x

-- An exported version would have to use $tag instead of Class
--
-- Enumeration shared by the unsigned fixed-width instances: natural
-- numbers converted with 'fromInteger', with every part beyond
-- @bitSize + 1@ reported as empty by 'cutOff'.
word :: (Bits a, Integral a) => Enumerate a
word = bounded
  where
    -- 'bounded' has no signature on purpose: it is used as a type witness
    -- for 'bitSize'' and so has to stay at the result type of 'word'.
    bounded = cutOff maxPart (unary (fromInteger . nat))
    maxPart = bitSize' bounded + 1
  
-- Enumeration shared by the signed fixed-width instances: integers
-- converted with 'fromInteger', with every part beyond @bitSize + 1@
-- reported as empty by 'cutOff'.
int :: (Bits a, Integral a) => Enumerate a
int = bounded
  where
    -- 'bounded' has no signature on purpose: it is used as a type witness
    -- for 'bitSize'' and so has to stay at the result type of 'int'.
    bounded = cutOff maxPart (unary fromInteger)
    maxPart = bitSize' bounded + 1

-- Limits an enumeration to its first @n@ parts: 'card' reports 0 for every
-- part greater than @n@ and is unchanged otherwise. The same cut-off is
-- mapped over 'optimal'; all other fields are left untouched.
cutOff :: Int -> Enumerate a -> Enumerate a
cutOff n e = e
  { card    = limitedCard
  , optimal = cutOff n <$> optimal e
  }
  where
    limitedCard p
      | p > n     = 0
      | otherwise = card e p

-- The bit size of the element type of any @f a@. The argument is only a
-- type witness and is never inspected; 'bitSize' is called on 'undefined'
-- at the element type.
bitSize' :: Bits a => f a -> Int
bitSize' = viaWitness undefined
  where
    viaWitness :: Bits b => b -> g b -> Int
    viaWitness witness _ = bitSize witness

-- The unsigned fixed-width types all use the 'word' enumeration, whose
-- cut-off is derived from the bit size of the instance type.
instance Enumerable Word where
  enumerate = word
instance Enumerable Word8 where
  enumerate = word
instance Enumerable Word16 where
  enumerate = word
instance Enumerable Word32 where
  enumerate = word
instance Enumerable Word64 where
  enumerate = word

-- The signed fixed-width types all use the 'int' enumeration, whose
-- cut-off is derived from the bit size of the instance type.
instance Enumerable Int where
  enumerate = int
instance Enumerable Int8 where
  enumerate = int
instance Enumerable Int16 where
  enumerate = int
instance Enumerable Int32 where
  enumerate = int
instance Enumerable Int64 where
  enumerate = int

-- | Not injective: values are built with 'encodeFloat' from a free pair of
-- its two arguments, and distinct pairs can encode the same number.
instance Enumerable Double where
  enumerate = unary (funcurry encodeFloat)

-- | Not injective: values are built with 'encodeFloat' from a free pair of
-- its two arguments, and distinct pairs can encode the same number.
instance Enumerable Float where
  enumerate = unary (funcurry encodeFloat)

-- This should be fixed with a bijective function.
-- | Not injective: a ratio is built with '%' from a free pair of a
-- numerator and a non-zero denominator, so distinct pairs can yield the
-- same ratio.
instance (Infinite a, Enumerable a) => Enumerable (Ratio a) where
  enumerate = unary (funcurry mkRatio)
    where
      mkRatio numer denom = numer % nonZero denom

-- | Contains only ASCII characters: characters are converted from the
-- 'Word' enumeration, which is cut off after its first 8 parts.
instance Enumerable Char where
  enumerate = cutOff 8 (unary wordToChar)
    where
      wordToChar :: Word -> Char
      wordToChar = toEnum . fromIntegral