{-# LANGUAGE NondecreasingIndentation #-}

-----------------------------------------------------------------------------

-----------------------------------------------------------------------------

-- |
-- Module      :  Disco.Enumerate
-- Copyright   :  disco team and contributors
-- Maintainer  :  byorgey@gmail.com
--
-- SPDX-License-Identifier: BSD-3-Clause
--
-- Enumerate values inhabiting Disco types.
module Disco.Enumerate (
  ValueEnumeration,

  -- * Base types
  enumVoid,
  enumUnit,
  enumBool,
  enumN,
  enumZ,
  enumF,
  enumQ,
  enumC,

  -- * Containers
  enumSet,
  --  , enumBag
  enumList,

  -- * Any type
  enumType,
  enumTypes,

  -- * Lifted functions that return lists
  enumerateType,
  enumerateTypes,
)
where

import qualified Data.Enumeration.Invertible as E
import Disco.AST.Generic (Side (..))
import Disco.Types
import Disco.Value

-- | An invertible enumeration (from "Data.Enumeration.Invertible") of
--   disco runtime 'Value's.  Every enumeration in this module is built
--   with 'E.mapE' from a pair of conversions to and from 'Value'.
type ValueEnumeration = E.IEnumeration Value

-- | The enumeration of type @Void@.  @Void@ has no inhabitants, so
--   this enumeration is empty.
enumVoid :: ValueEnumeration
enumVoid = E.void

-- | The enumeration of type @Unit@.  It contains exactly one value,
--   @unit@, represented at runtime as 'VUnit'.
enumUnit :: ValueEnumeration
enumUnit = E.singleton VUnit

-- | The enumeration of type @Bool@, in the order @[false, true]@.
--
--   A boolean is represented at runtime as an injection of @unit@.
--   The 'L' side comes first (@false@) and the 'R' side second (@true@).
enumBool :: ValueEnumeration
enumBool = E.mapE sideToValue valueToSide (E.finiteList [L, R])
 where
  sideToValue :: Side -> Value
  sideToValue side = VInj side VUnit

  valueToSide :: Value -> Side
  valueToSide v = case v of
    VInj side VUnit -> side
    _ -> error "enumBool.fromV: value isn't a bool"

-- | Pull the 'Rational' out of a 'VNum', discarding its
--   'RationalDisplay' field.
--
--   Partial: any 'Value' other than a 'VNum' is an 'error'.
valToRat :: Value -> Rational
valToRat v = case v of
  VNum _ q -> q
  _ -> error "valToRat: value isn't a number"

-- | Wrap a 'Rational' as a numeric 'Value', using the 'mempty'
--   'RationalDisplay'.
ratToVal :: Rational -> Value
ratToVal q = VNum mempty q

-- | Enumerate all values of type @Nat@ (0, 1, 2, ...).
--
--   Naturals are stored as numeric 'Value's.  The inverse direction
--   takes the 'floor' of the stored rational.
enumN :: ValueEnumeration
enumN = E.mapE natToValue valueToNat E.nat
 where
  natToValue :: Integer -> Value
  natToValue n = ratToVal (fromInteger n)

  valueToNat :: Value -> Integer
  valueToNat v = floor (valToRat v)

-- | Enumerate all values of type @Integer@ (0, 1, -1, 2, -2, ...).
--
--   Integers are stored as numeric 'Value's.  The inverse direction
--   takes the 'floor' of the stored rational.
enumZ :: ValueEnumeration
enumZ = E.mapE intToValue valueToInt E.int
 where
  intToValue :: Integer -> Value
  intToValue z = ratToVal (fromInteger z)

  valueToInt :: Value -> Integer
  valueToInt v = floor (valToRat v)

-- | Enumerate all values of type @Fractional@ in the Calkin-Wilf
--   order (1, 1/2, 2, 1/3, 3/2, 2/3, 3, ...).
enumF :: ValueEnumeration
enumF = asNumbers E.cw
 where
  -- Convert an enumeration of rationals into one of numeric values.
  asNumbers :: E.IEnumeration Rational -> ValueEnumeration
  asNumbers = E.mapE ratToVal valToRat

-- | Enumerate all values of type @Rational@ in the Calkin-Wilf order,
--   with negatives interleaved (0, 1, -1, 1/2, -1/2, 2, -2, ...).
enumQ :: ValueEnumeration
enumQ = asNumbers E.rat
 where
  -- Convert an enumeration of rationals into one of numeric values.
  asNumbers :: E.IEnumeration Rational -> ValueEnumeration
  asNumbers = E.mapE ratToVal valToRat

-- | Enumerate all Unicode characters, in the order given by
--   'E.boundedEnum' for 'Char'.
--
--   A character is stored as a numeric 'Value' holding its code point,
--   computed with 'fromEnum'.  The inverse direction applies 'toEnum'
--   to the 'floor' of that number.
enumC :: ValueEnumeration
enumC = E.mapE charToValue valueToChar (E.boundedEnum @Char)
 where
  charToValue :: Char -> Value
  charToValue c = ratToVal (fromIntegral (fromEnum c))

  valueToChar :: Value -> Char
  valueToChar v = toEnum (floor (valToRat v))

-- | Enumerate all /finite/ sets over an element type, given an
--   enumeration of the elements.
--
--   Think of each finite set as a binary string that marks which
--   elements of the element enumeration are members.  The sets are
--   then produced in the order of those binary strings.
--
--   At runtime a set is a 'VBag' in which every element has
--   multiplicity 1.
enumSet :: ValueEnumeration -> ValueEnumeration
enumSet e = E.mapE setToValue valueToSet (E.finiteSubsetOf e)
 where
  setToValue :: [Value] -> Value
  setToValue members = VBag [(m, 1) | m <- members]

  valueToSet :: Value -> [Value]
  valueToSet v = case v of
    VBag counted -> [m | (m, _) <- counted]
    _ -> error "enumSet.fromV: value isn't a set"

-- | Enumerate all /finite/ lists over an element type, given an
--   enumeration of the elements.  The order in which the lists are
--   generated is very difficult to describe.
--
--   At runtime a list is a chain of 'VCons' cells ending in 'VNil'.
enumList :: ValueEnumeration -> ValueEnumeration
enumList e = E.mapE listToValue valueToList (E.listOf e)
 where
  listToValue :: [Value] -> Value
  listToValue [] = VNil
  listToValue (x : rest) = VCons x (listToValue rest)

  valueToList :: Value -> [Value]
  valueToList v = case v of
    VNil -> []
    VCons hd tl -> hd : valueToList tl
    _ -> error "enumList.fromV: value isn't a list"

-- | Enumerate all functions from a finite domain, given enumerations
--   for the domain and codomain.
--
--   Degenerate cardinalities are handled directly, checked in this order:
--
--   * An empty domain has exactly one function.  That function can
--     never legitimately be called.
--   * A nonempty domain with an empty codomain has no functions.
--   * A one-element codomain has exactly one function, which is
--     constant.
--
--   In every other case the domain must be finite.  With an infinite
--   domain and a codomain of two or more values there are uncountably
--   many functions, so enumeration fails with a descriptive 'error'.
enumFunction :: ValueEnumeration -> ValueEnumeration -> ValueEnumeration
enumFunction xs ys =
  case (E.card xs, E.card ys) of
    (E.Finite 0, _) ->
      E.singleton (VFun $ \_ -> error "enumFunction: void function called")
    (_, E.Finite 0) -> E.void
    (_, E.Finite 1) -> E.singleton (VFun $ \_ -> E.select ys 0)
    -- Previously this fell through to 'E.functionOf', whose failure on
    -- an infinite domain gave no hint about what went wrong here.
    (E.Infinite, _) ->
      error "enumFunction: can't enumerate functions with an infinite domain"
    _ -> E.mapE toV fromV (E.functionOf xs ys)
 where
  toV :: (Value -> Value) -> Value
  toV = VFun

  fromV :: Value -> Value -> Value
  fromV (VFun f) = f
  fromV _ = error "enumFunction.fromV: value isn't a VFun"

-- | Enumerate all values of a product type, given enumerations of the
--   two component types.  Infinite component types are interleaved
--   fairly (via 'E.><').  Each pair becomes a 'VPair' at runtime.
enumProd :: ValueEnumeration -> ValueEnumeration -> ValueEnumeration
enumProd xs ys = E.mapE (uncurry VPair) valueToPair (xs E.>< ys)
 where
  valueToPair :: Value -> (Value, Value)
  valueToPair v = case v of
    VPair a b -> (a, b)
    _ -> error "enumProd.fromV: value isn't a pair"

-- | Enumerate all values of a sum type, given enumerations of the two
--   component types (combined with 'E.<+>').  Left values become
--   @VInj L@ and right values become @VInj R@.
enumSum :: ValueEnumeration -> ValueEnumeration -> ValueEnumeration
enumSum xs ys = E.mapE (either (VInj L) (VInj R)) valueToEither (xs E.<+> ys)
 where
  valueToEither :: Value -> Either Value Value
  valueToEither v = case v of
    VInj L a -> Left a
    VInj R b -> Right b
    _ -> error "enumSum.fromV: value isn't a sum"

-- | Enumerate the values of a given type.
--
--   Base types map to their dedicated enumerations.  Sets, lists,
--   products, sums and functions are built recursively from the
--   enumerations of their component types.  Any other type is an
--   'error'.
enumType :: Type -> ValueEnumeration
enumType ty = case ty of
  TyVoid -> enumVoid
  TyUnit -> enumUnit
  TyBool -> enumBool
  TyN -> enumN
  TyZ -> enumZ
  TyF -> enumF
  TyQ -> enumQ
  TyC -> enumC
  TySet elemTy -> enumSet (enumType elemTy)
  TyList elemTy -> enumList (enumType elemTy)
  l :*: r -> enumProd (enumType l) (enumType r)
  l :+: r -> enumSum (enumType l) (enumType r)
  dom :->: cod -> enumFunction (enumType dom) (enumType cod)
  _ -> error $ "enumType: can't enumerate " ++ show ty

-- | Enumerate a finite product of types.
--
--   Each element is a list holding one value per input type, in the
--   same order as the types.  The empty product has exactly one
--   element, the empty list.
enumTypes :: [Type] -> E.IEnumeration [Value]
enumTypes = foldr consEnum (E.singleton [])
 where
  -- Prepend the values of one type to every tuple in the rest.
  consEnum :: Type -> E.IEnumeration [Value] -> E.IEnumeration [Value]
  consEnum t rest = E.mapE (uncurry (:)) unconsValues (enumType t E.>< rest)

  unconsValues :: [Value] -> (Value, [Value])
  unconsValues (v : vs) = (v, vs)
  unconsValues [] = error "enumTypes.fromL: empty list not in enumeration range"

-- | Produce an actual list of the values of a type.  The list is
--   infinite for types with infinitely many values, such as @N@.
enumerateType :: Type -> [Value]
enumerateType ty = E.enumerate (enumType ty)

-- | Produce an actual list of the tuples in a finite product of types.
--   Each tuple is a list with one value per input type.
enumerateTypes :: [Type] -> [[Value]]
enumerateTypes tys = E.enumerate (enumTypes tys)