{-# LANGUAGE DeriveDataTypeable #-}
{-# LANGUAGE TemplateHaskell #-}

-- | Everything you need to construct an enumeration for an algebraic type.
--
-- Describe each constructor of the type: use 'pure' for nullary
-- constructors, and 'unary' together with 'funcurry' for constructors of
-- positive arity. Then combine the constructors with 'consts'. Example:
--
-- @
--  instance Enumerable a => Enumerable [a] where
--    enumerate = consts [unary (funcurry (:)), pure []]
-- @
--
-- There's also a handy Template Haskell function for automatic derivation.
module Test.Feat.Class
  ( Enumerable(..)

    -- ** Building instances
  , Constructor
  , nullary
  , unary
  , funcurry
  , consts

    -- ** Accessing the enumerator of an instance
  , shared
  , optimal

    -- *** Free pairs
  , FreePair(..)

    -- ** Deriving instances with template Haskell
  , deriveEnumerable
  , deriveEnumerable'
  , ConstructorDeriv
  , dAll
  , dExcluding
  , dExcept
    -- autoCon,
    -- autoCons
  ) where

-- testing-feat
import Test.Feat.Enumerate
import Test.Feat.Internals.Tag(Tag(Class))
import Test.Feat.Internals.Derive
import Test.Feat.Internals.Newtypes
-- base
import Data.Typeable
import Data.Monoid
-- template-haskell
import Language.Haskell.TH
import Language.Haskell.TH.Syntax
-- base - only for instances
import Data.Word
import Data.Int
import Data.Bits
import Data.Ratio

-- | A class of functionally enumerable types.
class Typeable a => Enumerable a where
  -- | The interface for defining an instance. Do not use it directly
  -- elsewhere: use 'shared' when combining enumerations and 'optimal'
  -- when accessing the data of an enumeration.
  enumerate :: Enumerate a

-- | Version of 'enumerate' that ensures that the enumeration is shared
-- between all accesses. Should always be used when combining
-- enumerations.
shared :: Enumerable a => Enumerate a
shared = eShare Class enumerate

-- | An optimal version of enumerate. Used by all
-- library functions that access enumerated values (but not
-- by combining functions). Library functions should ensure that
-- @optimal@ is not reevaluated.
optimal :: Enumerable a => Enumerate a
optimal = optimise shared

-- | A free pair constructor. The cost of constructing a free pair
-- is equal to the sum of the costs of its components.
newtype FreePair a b = Free {free :: (a,b)}
  deriving (Show, Typeable)

-- | Uncurry a function (typically a constructor) to a function on free pairs.
funcurry :: (a -> b -> c) -> FreePair a b -> c
funcurry f = uncurry f . free

-- | Pairs up the shared enumerations of both components.
instance (Enumerable a, Enumerable b) => Enumerable (FreePair a b) where
  enumerate = curry Free <$> shared <*> shared

-- | The enumeration contributed by a single constructor; a plain synonym
-- for 'Enumerate'. Values are built with 'nullary' and 'unary' and
-- combined with 'consts'.
type Constructor = Enumerate

-- | For nullary constructors such as @True@ and @[]@.
nullary :: a -> Constructor a
nullary = pure

-- | For any non-nullary constructor. Apply 'funcurry' until the type of
-- the result is unary (i.e. n-1 times where n is the number of fields
-- of the constructor).
unary :: Enumerable a => (a -> b) -> Constructor b
unary f = f <$> shared

-- | Produces the enumeration of a type given the enumerators for each of its
-- constructors. The result of 'unary' should typically not be used
-- directly in an instance even if it only has one constructor. So you
-- should apply consts even in that case.
consts :: [Constructor a] -> Enumerate a
consts xs = pay $ mconcat xs

--------------------------------------------------------------------
-- Automatic derivation

-- | Derive an instance of Enumerable with Template Haskell. To derive
-- an instance for @Enumerable A@, just put this as a top level declaration
-- in your module (with the TemplateHaskell extension enabled):
--
-- @
--   deriveEnumerable ''A
-- @
deriveEnumerable :: Name -> Q [Dec]
deriveEnumerable = deriveEnumerable' . dAll

-- | A derivation request: the name of the type, paired with replacement
-- expressions for individual constructors (keyed by constructor name).
-- Constructors without an entry get the default derived enumeration.
type ConstructorDeriv = (Name, [(Name, ExpQ)])

-- | Derive every constructor of the named type in the default way
-- (no replacement rules).
dAll :: Name -> ConstructorDeriv
dAll n = (n,[])

-- | Replace the enumeration of the named constructor with 'mempty'.
-- The rule is prepended, so it takes precedence over earlier rules for
-- the same constructor.
dExcluding :: Name -> ConstructorDeriv -> ConstructorDeriv
dExcluding n (t,nrs) = (t,(n,[|mempty|]):nrs)

-- | Replace the enumeration of the named constructor with the given
-- expression. The rule is prepended, so it takes precedence over earlier
-- rules for the same constructor.
dExcept :: Name -> ExpQ -> ConstructorDeriv -> ConstructorDeriv
dExcept n e (t,nrs) = (t,(n,e):nrs)

-- | Derive an instance of Enumerable with Template Haskell, with
-- rules for some specific constructors.
--
-- Compilation is aborted if a rule names a constructor that does not
-- belong to the type.
deriveEnumerable' :: ConstructorDeriv -> Q [Dec]
deriveEnumerable' (n,cse) =
  fmap return $ instanceFor ''Enumerable [enumDef] n
  where
    -- Builds the @enumerate = consts [...]@ binding from the constructors
    -- (name and field types) of the type.
    enumDef :: [(Name,[Type])] -> Q Dec
    enumDef cons = do
      sanityCheck
      fmap mk_freqs_binding [|consts $ex |]
      where
        ex = listE $ map cone cons

        -- A user supplied rule wins over the default enumeration.
        cone con@(c,_) = maybe (cone' con) id $ lookup c cse

        -- Default enumeration: 'nullary' for constructors without fields,
        -- otherwise 'unary' with one 'funcurry' per field after the first.
        cone' (c,[]) = [|nullary $(conE c)|]
        cone' (c,_:vs) =
          [|unary $(foldr appE (conE c) (map (const [|funcurry|] ) vs) )|]

        mk_freqs_binding :: Exp -> Dec
        mk_freqs_binding e = ValD (VarP 'enumerate ) (NormalB e) []

        -- Every rule must refer to a constructor of the type. 'fail' is
        -- used (rather than 'error') so the splice is aborted with a
        -- regular compile error.
        sanityCheck = case filter (`notElem` map fst cons) (map fst cse) of
          [] -> return ()
          xs -> fail $ "Invalid constructors for "++show n++": "++show xs

---------------------------------------------------------------------
-- Instances

-- Instances for the basic types below are generated by a top-level splice.
(let
    it = mapM (instanceFor ''Enumerable [enumDef])
      [ ''[]
      , ''Bool
      , ''()
      , ''(,)
      , ''(,,)
      , ''(,,,)
      , ''(,,,,)
      , ''(,,,,,)
      , ''(,,,,,,) -- This is as far as typeable goes...
      , ''Either
      , ''Maybe
      , ''Ordering
      ]

    -- Circumventing the stage restrictions by means of code repetition.
    enumDef :: [(Name,[Type])] -> Q Dec
    enumDef cons = fmap mk_freqs_binding [|consts $ex |]
      where
        ex = listE $ map cone cons
        cone (n,[]) = [|pure $(conE n)|]
        cone (n,_:vs) =
          [|unary $(foldr appE (conE n) (map (const [|funcurry|] ) vs) )|]
        mk_freqs_binding :: Exp -> Dec
        mk_freqs_binding e = ValD (VarP 'enumerate) (NormalB e) []
  in it)

-- Builds an enumeration from two functions of the part number @p@
-- (counted from 0): @car p@ and @sel p@ are wrapped in a 'Finite' for each
-- part. The second field of the result just returns the enumeration itself.
simpleEnum car sel =
  let e = Enumerate
            (toRev$ map (\p -> Finite (car p) (sel p)) [0..])
            (return e)
  in e

-- This instance is quite important. It needs to be exponential for
-- the other instances to work.
instance Infinite a => Enumerable (Nat a) where
  enumerate = simpleEnum partSize pick
    where
      -- Number of values in part @p@: none, then one, then doubling.
      partSize p
        | p <= 0    = 0
        | p == 1    = 1
        | otherwise = 2^(p-2)

      -- The value at index @i@ of part @p@.
      pick :: Num a => Int -> Index -> Nat a
      pick 1 0 = Nat 0
      pick p i = Nat (2^(p-2) + fromInteger i)

-- This instance is used by the Int* instances and needs to be exponential as
-- well.
instance Enumerable Integer where
  enumerate = unary fromSignAndNat
    where
      -- The flag selects the negative branch; @-i-1@ keeps the two
      -- branches from overlapping at zero.
      fromSignAndNat (Free (negative, Nat i))
        | negative  = -i-1
        | otherwise = i

instance (Infinite a, Enumerable a) => Enumerable (NonZero a) where
  enumerate = unary (NonZero . skipZero)
    where
      -- Shift the non-negative values up by one so zero is never produced.
      skipZero a
        | a >= 0    = a + 1
        | otherwise = a

-- An exported version would have to use $tag instead of Class

-- Enumeration built from natural numbers, cut off after
-- @bitSize + 1@ parts of the underlying enumeration.
word :: (Bits a, Integral a) => Enumerate a
word = limited
  where
    -- Kept as a local (monomorphic) binding so that 'bitSize'' is asked
    -- about the very type being enumerated.
    limited = cutOff (bitSize' limited + 1) (unary (fromInteger . nat))

-- Enumeration built from integers, cut off after
-- @bitSize + 1@ parts of the underlying enumeration.
int :: (Bits a, Integral a) => Enumerate a
int = limited
  where
    limited = cutOff (bitSize' limited + 1) (unary fromInteger)

-- Keeps only the first @n@ parts of an enumeration; the same cut is
-- applied to what its @optimiser@ field yields.
cutOff :: Int -> Enumerate a -> Enumerate a
cutOff n e =
  Enumerate (toRev (take n (parts e))) (cutOff n <$> optimiser e)

-- 'bitSize' of the element type of a container; the argument is only used
-- to fix the type and is never inspected.
bitSize' :: Bits a => f a -> Int
bitSize' proxy = go undefined proxy
  where
    go :: Bits a => a -> f a -> Int
    go witness _ = bitSize witness

instance Enumerable Word   where enumerate = word
instance Enumerable Word8  where enumerate = word
instance Enumerable Word16 where enumerate = word
instance Enumerable Word32 where enumerate = word
instance Enumerable Word64 where enumerate = word

instance Enumerable Int   where enumerate = int
instance Enumerable Int8  where enumerate = int
instance Enumerable Int16 where enumerate = int
instance Enumerable Int32 where enumerate = int
instance Enumerable Int64 where enumerate = int

-- | Not injective
instance Enumerable Double where
  enumerate = unary $ funcurry encodeFloat

-- | Not injective
instance Enumerable Float where
  enumerate = unary $ funcurry encodeFloat

-- This should be fixed with a bijective function.
-- | Not injective
instance (Infinite a, Enumerable a) => Enumerable (Ratio a) where
  enumerate = unary (funcurry ratio)
    where
      -- Numerator paired with a 'NonZero' denominator.
      ratio numer denom = numer % nonZero denom

-- | Contains only ASCII characters
instance Enumerable Char where
  enumerate = cutOff 8 (unary wordToChar)
    where
      -- Characters are produced from enumerated 'Word' values.
      wordToChar :: Word -> Char
      wordToChar = toEnum . fromIntegral