{-# LANGUAGE DataKinds            #-}
{-# LANGUAGE DeriveDataTypeable   #-}
{-# LANGUAGE EmptyCase            #-}
{-# LANGUAGE GADTs                #-}
{-# LANGUAGE KindSignatures       #-}
{-# LANGUAGE ScopedTypeVariables  #-}
{-# LANGUAGE StandaloneDeriving   #-}
{-# LANGUAGE TypeOperators        #-}
{-# LANGUAGE UndecidableInstances #-}
-- | Finite numbers.
--
-- This module is designed to be imported as
--
-- @
-- import Data.Fin (Fin (..))
-- import qualified Data.Fin as Fin
-- @
--
module Data.Fin (
    Fin (..),
    cata,
    -- * Showing
    explicitShow,
    explicitShowsPrec,
    -- * Conversions
    toNat,
    fromNat,
    toNatural,
    toInteger,
    -- * Interesting
    mirror,
    inverse,
    universe,
    inlineUniverse,
    universe1,
    inlineUniverse1,
    absurd,
    boring,
    -- * Plus
    weakenLeft,
    weakenLeft1,
    weakenRight,
    weakenRight1,
    append,
    split,
    -- * Min and max
    isMin,
    isMax,
    -- * Aliases
    fin0, fin1, fin2, fin3, fin4, fin5, fin6, fin7, fin8, fin9,
    ) where

import Control.DeepSeq    (NFData (..))
import Data.Bifunctor     (bimap)
import Data.Hashable      (Hashable (..))
import Data.List.NonEmpty (NonEmpty (..))
import Data.Proxy         (Proxy (..))
import Data.Type.Nat      (Nat (..))
import Data.Typeable      (Typeable)
import GHC.Exception      (ArithException (..), throw)
import Numeric.Natural    (Natural)

import qualified Data.List.NonEmpty as NE
import qualified Data.Type.Nat      as N
import qualified Test.QuickCheck    as QC

-------------------------------------------------------------------------------
-- Type
-------------------------------------------------------------------------------

-- | Finite numbers: @[0..n-1]@.
data Fin (n :: Nat) where
    FZ :: Fin ('S n)
    FS :: Fin n -> Fin ('S n)
  deriving (Typeable)

-------------------------------------------------------------------------------
-- Instances
-------------------------------------------------------------------------------

deriving instance Eq (Fin n)
deriving instance Ord (Fin n)

-- | 'Fin' is printed as 'Natural'.
--
-- To see explicit structure, use 'explicitShow' or 'explicitShowsPrec'.
instance Show (Fin n) where
    showsPrec d = showsPrec d . toNatural

-- | Operations modulo @n@.
--
-- Every operation goes through 'fromInteger', which reduces its argument
-- modulo @n@; for @'Fin' 'N.Nat0'@ that reduction divides by zero.
--
-- >>> map fromInteger [0, 1, 2, 3, 4, -5] :: [Fin N.Nat3]
-- [0,1,2,0,1,1]
--
-- >>> fromInteger 42 :: Fin N.Nat0
-- *** Exception: divide by zero
-- ...
--
-- >>> signum (FZ :: Fin N.Nat1)
-- 0
--
-- >>> signum (3 :: Fin N.Nat4)
-- 1
--
-- >>> 2 + 3 :: Fin N.Nat4
-- 1
--
-- >>> 2 * 3 :: Fin N.Nat4
-- 2
--
instance N.SNatI n => Num (Fin n) where
    abs = id

    -- The nested patterns are needed: @FS FZ@ only typechecks once the
    -- argument is known to have at least two inhabitants.
    signum FZ          = FZ
    signum (FS FZ)     = FS FZ
    signum (FS (FS _)) = FS FZ

    fromInteger = unsafeFromNum . (`mod` N.reflectToNum (Proxy :: Proxy n))

    n + m = fromInteger (toInteger n + toInteger m)
    n * m = fromInteger (toInteger n * toInteger m)
    n - m = fromInteger (toInteger n - toInteger m)

    negate = fromInteger . negate . toInteger

-- | 'toRational' counts the 'FS' constructors, i.e. yields the numeric value.
instance N.SNatI n => Real (Fin n) where
    toRational = cata 0 succ

-- | 'quot' works only on @'Fin' n@ where @n@ is prime.
--
-- @'quot' a b@ is @a * 'inverse' b@, and 'quotRem' always reports a
-- remainder of @0@.
instance N.SNatI n => Integral (Fin n) where
    toInteger = cata 0 succ

    quotRem a b = (quot a b, 0)
    quot a b = a * inverse b

-- | Mirror the values, 'minBound' becomes 'maxBound', etc.
--
-- >>> map mirror universe :: [Fin N.Nat4]
-- [3,2,1,0]
--
-- >>> reverse universe :: [Fin N.Nat4]
-- [3,2,1,0]
--
-- @since 0.1.1
--
mirror :: forall n. N.InlineInduction n => Fin n -> Fin n
mirror = getMirror (N.inlineInduction start step) where
    start :: Mirror 'Z
    start = Mirror id

    step :: forall m. N.InlineInduction m => Mirror m -> Mirror ('S m)
    step (Mirror rec) = Mirror $ \n -> case n of
        -- 'FZ' maps to the largest value of @'Fin' ('S m)@.
        FZ   -> getMaxBound (N.inlineInduction (MaxBound FZ) (MaxBound . FS . getMaxBound))
        -- @FS k@ maps to the mirror of @k@ in @'Fin' m@, embedded unchanged.
        FS m -> weakenLeft1 (rec m)

newtype Mirror n = Mirror { getMirror :: Fin n -> Fin n }

-- | Multiplicative inverse.
--
-- Works when the argument is coprime with @n@; in particular for every
-- non-zero argument when @n@ is prime.
--
-- For an argument that is not invertible the result is not a true inverse
-- (e.g. @'inverse' 0 = 0@).
--
-- >>> map inverse universe :: [Fin N.Nat5]
-- [0,1,3,2,4]
--
-- >>> zipWith (*) universe (map inverse universe) :: [Fin N.Nat5]
-- [0,1,1,1,1]
--
-- Adaptation of [pseudo-code in Wikipedia](https://en.wikipedia.org/wiki/Extended_Euclidean_algorithm#Modular_integers)
--
inverse :: forall n. N.SNatI n => Fin n -> Fin n
inverse = fromInteger . iter 0 n 1 . toInteger where
    n = N.reflectToNum (Proxy :: Proxy n)

    -- Extended Euclid: (t, r) and (t', r') are the two most recent rows.
    iter t _ _ 0
        | t < 0     = t + n
        | otherwise = t
    iter t r t' r' =
        let q = r `div` r'
        in iter t' r' (t - q * t') (r - q * r')

-- | 'Enum' instance.
--
-- 'enumFrom' and 'enumFromThen' stop at the bounds of @'Fin' n@, so e.g.
-- @[x ..]@ is a finite list.
--
-- >>> [1 ..] :: [Fin N.Nat4]
-- [1,2,3]
--
-- >>> [3, 1 ..] :: [Fin N.Nat4]
-- [3,1]
--
-- >>> [0, 2 ..] :: [Fin N.Nat5]
-- [0,2,4]
--
instance N.SNatI n => Enum (Fin n) where
    fromEnum = go where
        go :: Fin m -> Int
        go FZ     = 0
        go (FS n) = succ (go n)

    toEnum = unsafeFromNum

    -- The class defaults enumerate the unbounded 'Int' range and call
    -- 'toEnum' on every element, which throws 'Overflow' / 'Underflow' once
    -- the list leaves @[0 .. n-1]@. Clamp to the bounds instead.
    enumFrom x = map toEnum [fromEnum x .. top]
      where
        top = N.reflectToNum (Proxy :: Proxy n) - 1

    enumFromThen x y = map toEnum [i, j .. limit]
      where
        i = fromEnum x
        j = fromEnum y
        limit | j >= i    = N.reflectToNum (Proxy :: Proxy n) - 1
              | otherwise = 0

instance (n ~ 'S m, N.SNatI m) => Bounded (Fin n) where
    minBound = FZ
    maxBound = getMaxBound $ N.induction
        (MaxBound FZ)
        (MaxBound . FS . getMaxBound)

newtype MaxBound n = MaxBound { getMaxBound :: Fin ('S n) }

instance NFData (Fin n) where
    rnf FZ     = ()
    rnf (FS n) = rnf n

instance Hashable (Fin n) where
    hashWithSalt salt = hashWithSalt salt . cata (0 :: Integer) succ

-------------------------------------------------------------------------------
-- QuickCheck
-------------------------------------------------------------------------------

-- | 'QC.arbitrary' generates every value of @'Fin' n@ with equal probability.
instance (n ~ 'S m, N.SNatI m) => QC.Arbitrary (Fin n) where
    arbitrary = getArb $ N.induction (Arb (return FZ)) step where
        step :: forall p. N.SNatI p => Arb p -> Arb ('S p)
        step (Arb gen) = Arb $ QC.frequency
            [ (1,     return FZ)
            , (count, fmap FS gen)
            ]
          where
            -- 'gen' produces @'Fin' ('S p)@, i.e. @p + 1@ distinct values.
            -- Weighting the 'FS' branch by that count keeps the result
            -- uniform; a weight of @p@ would make @'Fin' 'N.Nat2'@ always 0.
            count = 1 + N.reflectToNum (Proxy :: Proxy p)

    shrink = shrink

-- | Shrink by one step towards 'FZ'.
shrink :: Fin n -> [Fin n]
shrink FZ      = []
shrink (FS FZ) = [FZ]
shrink (FS n)  = map FS (shrink n)

newtype Arb n = Arb { getArb :: QC.Gen (Fin ('S n)) }

instance QC.CoArbitrary (Fin n) where
    coarbitrary FZ     = QC.variant (0 :: Int)
    coarbitrary (FS n) = QC.variant (1 :: Int) . QC.coarbitrary n

instance (n ~ 'S m, N.SNatI m) => QC.Function (Fin n) where
    function = case N.snat :: N.SNat m of
        N.SZ -> QC.functionMap (\FZ -> ()) (\() -> FZ)
        N.SS -> QC.functionMap isMin (maybe FZ FS)

-- TODO: https://github.com/nick8325/quickcheck/pull/283
-- newtype Fun b m = Fun { getFun :: (Fin ('S m) -> b) -> Fin ('S m) QC.:-> b }

-------------------------------------------------------------------------------
-- Showing
-------------------------------------------------------------------------------

-- | 'show' displaying a structure of 'Fin'.
--
-- >>> explicitShow (0 :: Fin N.Nat1)
-- "FZ"
--
-- >>> explicitShow (2 :: Fin N.Nat3)
-- "FS (FS FZ)"
--
explicitShow :: Fin n -> String
explicitShow n = explicitShowsPrec 0 n ""

-- | 'showsPrec' displaying a structure of 'Fin'.
explicitShowsPrec :: Int -> Fin n -> ShowS
explicitShowsPrec _ FZ     = showString "FZ"
explicitShowsPrec d (FS n) = showParen (d > 10)
    $ showString "FS "
    . explicitShowsPrec 11 n

-------------------------------------------------------------------------------
-- Conversions
-------------------------------------------------------------------------------

-- | Fold 'Fin'.
cata :: forall a n. a -> (a -> a) -> Fin n -> a
cata z f = go where
    go :: Fin m -> a
    go FZ     = z
    go (FS n) = f (go n)

-- | Convert to 'Nat'.
toNat :: Fin n -> N.Nat
toNat = cata Z S

-- | Convert from 'Nat'.
--
-- >>> fromNat N.nat1 :: Maybe (Fin N.Nat2)
-- Just 1
--
-- >>> fromNat N.nat1 :: Maybe (Fin N.Nat1)
-- Nothing
--
fromNat :: N.SNatI n => N.Nat -> Maybe (Fin n)
fromNat = appNatToFin (N.induction start step) where
    start :: NatToFin 'Z
    start = NatToFin $ const Nothing

    step :: NatToFin n -> NatToFin ('S n)
    step (NatToFin f) = NatToFin $ \n -> case n of
        Z   -> Just FZ
        S m -> fmap FS (f m)

newtype NatToFin n = NatToFin { appNatToFin :: N.Nat -> Maybe (Fin n) }

-- | Convert to 'Natural'.
toNatural :: Fin n -> Natural
toNatural = cata 0 succ

-- | Convert from any 'Ord' 'Num'.
--
-- Throws 'Underflow' for negative input and 'Overflow' for input @>= n@.
unsafeFromNum :: forall n i. (Num i, Ord i, N.SNatI n) => i -> Fin n
unsafeFromNum = appUnsafeFromNum (N.induction start step) where
    start :: UnsafeFromNum i 'Z
    start = UnsafeFromNum $ \n -> case compare n 0 of
        LT -> throw Underflow
        EQ -> throw Overflow
        GT -> throw Overflow

    step :: UnsafeFromNum i m -> UnsafeFromNum i ('S m)
    step (UnsafeFromNum f) = UnsafeFromNum $ \n -> case compare n 0 of
        EQ -> FZ
        GT -> FS (f (n - 1))
        LT -> throw Underflow

newtype UnsafeFromNum i n = UnsafeFromNum { appUnsafeFromNum :: i -> Fin n }

-------------------------------------------------------------------------------
-- "Interesting" stuff
-------------------------------------------------------------------------------

-- | All values. @[minBound .. maxBound]@ won't work for @'Fin' 'N.Nat0'@.
--
-- >>> universe :: [Fin N.Nat3]
-- [0,1,2]
universe :: N.SNatI n => [Fin n]
universe = getUniverse $ N.induction (Universe []) step where
    step :: Universe n -> Universe ('S n)
    step (Universe xs) = Universe (FZ : map FS xs)

-- | Like 'universe' but 'NonEmpty'.
--
-- >>> universe1 :: NonEmpty (Fin N.Nat3)
-- 0 :| [1,2]
universe1 :: N.SNatI n => NonEmpty (Fin ('S n))
universe1 = getUniverse1 $ N.induction (Universe1 (FZ :| [])) step where
    step :: Universe1 n -> Universe1 ('S n)
    step (Universe1 xs) = Universe1 (NE.cons FZ (fmap FS xs))

-- | 'universe' which will be fully inlined, if @n@ is known at compile time.
--
-- >>> inlineUniverse :: [Fin N.Nat3]
-- [0,1,2]
inlineUniverse :: N.InlineInduction n => [Fin n]
inlineUniverse = getUniverse $ N.inlineInduction (Universe []) step where
    step :: Universe n -> Universe ('S n)
    step (Universe xs) = Universe (FZ : map FS xs)

-- | Like 'universe1', but fully inlined if @n@ is known at compile time.
--
-- >>> inlineUniverse1 :: NonEmpty (Fin N.Nat3)
-- 0 :| [1,2]
inlineUniverse1 :: N.InlineInduction n => NonEmpty (Fin ('S n))
inlineUniverse1 = getUniverse1 $ N.inlineInduction (Universe1 (FZ :| [])) step where
    step :: Universe1 n -> Universe1 ('S n)
    step (Universe1 xs) = Universe1 (NE.cons FZ (fmap FS xs))

newtype Universe  n = Universe  { getUniverse  :: [Fin n] }
newtype Universe1 n = Universe1 { getUniverse1 :: NonEmpty (Fin ('S n)) }

-- | @'Fin' 'N.Nat0'@ is not inhabited.
absurd :: Fin N.Nat0 -> b
absurd n = case n of {}

-- | Counting to one is boring.
--
-- >>> boring
-- 0
boring :: Fin N.Nat1
boring = FZ

-------------------------------------------------------------------------------
-- min and max
-------------------------------------------------------------------------------

-- | Return a one less, or 'Nothing' for 'FZ'.
--
-- >>> isMin (FZ :: Fin N.Nat1)
-- Nothing
--
-- >>> map isMin universe :: [Maybe (Fin N.Nat3)]
-- [Nothing,Just 0,Just 1,Just 2]
--
-- @since 0.1.1
--
isMin :: Fin ('S n) -> Maybe (Fin n)
isMin FZ     = Nothing
isMin (FS n) = Just n

-- | Return the same value in a type one smaller, or 'Nothing' if the
-- argument is the largest value of @'Fin' ('S n)@.
--
-- >>> isMax (FZ :: Fin N.Nat1)
-- Nothing
--
-- >>> map isMax universe :: [Maybe (Fin N.Nat3)]
-- [Just 0,Just 1,Just 2,Nothing]
--
-- @since 0.1.1
--
isMax :: forall n. N.InlineInduction n => Fin ('S n) -> Maybe (Fin n)
isMax = getIsMax (N.inlineInduction start step) where
    start :: IsMax 'Z
    start = IsMax $ \_ -> Nothing

    step :: IsMax m -> IsMax ('S m)
    step (IsMax rec) = IsMax $ \n -> case n of
        FZ   -> Just FZ
        FS m -> fmap FS (rec m)

newtype IsMax n = IsMax { getIsMax :: Fin ('S n) -> Maybe (Fin n) }

-------------------------------------------------------------------------------
-- Append & Split
-------------------------------------------------------------------------------

-- | Add one, moving into a type one larger (this is 'FS').
--
-- >>> map weakenRight1 universe :: [Fin N.Nat5]
-- [1,2,3,4]
--
-- @since 0.1.1
weakenRight1 :: Fin n -> Fin ('S n)
weakenRight1 = FS

-- | Embed into a type one larger, keeping the value.
--
-- >>> map weakenLeft1 universe :: [Fin N.Nat5]
-- [0,1,2,3]
--
-- @since 0.1.1
weakenLeft1 :: N.InlineInduction n => Fin n -> Fin ('S n)
weakenLeft1 = getWeaken1 (N.inlineInduction start step) where
    start :: Weaken1 'Z
    start = Weaken1 absurd

    step :: Weaken1 n -> Weaken1 ('S n)
    step (Weaken1 go) = Weaken1 $ \n -> case n of
        FZ    -> FZ
        FS n' -> FS (go n')

newtype Weaken1 n = Weaken1 { getWeaken1 :: Fin n -> Fin ('S n) }

-- | Embed @'Fin' n@ into @'Fin' (n + m)@, keeping the value.
--
-- >>> map (weakenLeft (Proxy :: Proxy N.Nat2)) (universe :: [Fin N.Nat3])
-- [0,1,2]
weakenLeft :: forall n m. N.InlineInduction n => Proxy m -> Fin n -> Fin (N.Plus n m)
weakenLeft _ = getWeakenLeft (N.inlineInduction start step :: WeakenLeft m n) where
    start :: WeakenLeft m 'Z
    start = WeakenLeft absurd

    step :: WeakenLeft m p -> WeakenLeft m ('S p)
    step (WeakenLeft go) = WeakenLeft $ \n -> case n of
        FZ    -> FZ
        FS n' -> FS (go n')

newtype WeakenLeft m n = WeakenLeft { getWeakenLeft :: Fin n -> Fin (N.Plus n m) }

-- | Add @n@, moving @'Fin' m@ into @'Fin' (n + m)@.
--
-- >>> map (weakenRight (Proxy :: Proxy N.Nat2)) (universe :: [Fin N.Nat3])
-- [2,3,4]
weakenRight :: forall n m. N.InlineInduction n => Proxy n -> Fin m -> Fin (N.Plus n m)
weakenRight _ = getWeakenRight (N.inlineInduction start step :: WeakenRight m n) where
    -- Explicit signatures, like the sibling helpers: @'N.Plus' 'Z m@ reduces
    -- to @m@, so 'id' fits; each step wraps one more 'FS'.
    start :: WeakenRight m 'Z
    start = WeakenRight id

    step :: WeakenRight m p -> WeakenRight m ('S p)
    step (WeakenRight go) = WeakenRight $ \x -> FS $ go x

newtype WeakenRight m n = WeakenRight { getWeakenRight :: Fin m -> Fin (N.Plus n m) }

-- | Append two 'Fin's together.
--
-- >>> append (Left fin2 :: Either (Fin N.Nat5) (Fin N.Nat4))
-- 2
--
-- >>> append (Right fin2 :: Either (Fin N.Nat5) (Fin N.Nat4))
-- 7
--
append :: forall n m. N.InlineInduction n => Either (Fin n) (Fin m) -> Fin (N.Plus n m)
append (Left n)  = weakenLeft (Proxy :: Proxy m) n
append (Right m) = weakenRight (Proxy :: Proxy n) m

-- | Inverse of 'append'.
--
-- >>> split fin2 :: Either (Fin N.Nat2) (Fin N.Nat3)
-- Right 0
--
-- >>> split fin1 :: Either (Fin N.Nat2) (Fin N.Nat3)
-- Left 1
--
-- >>> map split universe :: [Either (Fin N.Nat2) (Fin N.Nat3)]
-- [Left 0,Left 1,Right 0,Right 1,Right 2]
--
split :: forall n m. N.InlineInduction n => Fin (N.Plus n m) -> Either (Fin n) (Fin m)
split = getSplit (N.inlineInduction start step) where
    start :: Split m 'Z
    start = Split Right

    step :: Split m p -> Split m ('S p)
    step (Split go) = Split $ \x -> case x of
        FZ    -> Left FZ
        FS x' -> bimap FS id $ go x'

newtype Split m n = Split { getSplit :: Fin (N.Plus n m) -> Either (Fin n) (Fin m) }

-------------------------------------------------------------------------------
-- Aliases
-------------------------------------------------------------------------------

-- @finK@ is the number @K@ in any 'Fin' with at least @K + 1@ inhabitants.

fin0 :: Fin (N.Plus N.Nat0 ('S n))
fin1 :: Fin (N.Plus N.Nat1 ('S n))
fin2 :: Fin (N.Plus N.Nat2 ('S n))
fin3 :: Fin (N.Plus N.Nat3 ('S n))
fin4 :: Fin (N.Plus N.Nat4 ('S n))
fin5 :: Fin (N.Plus N.Nat5 ('S n))
fin6 :: Fin (N.Plus N.Nat6 ('S n))
fin7 :: Fin (N.Plus N.Nat7 ('S n))
fin8 :: Fin (N.Plus N.Nat8 ('S n))
fin9 :: Fin (N.Plus N.Nat9 ('S n))

fin0 = FZ
fin1 = FS fin0
fin2 = FS fin1
fin3 = FS fin2
fin4 = FS fin3
fin5 = FS fin4
fin6 = FS fin5
fin7 = FS fin6
fin8 = FS fin7
fin9 = FS fin8