module Agda.TypeChecking.SizedTypes.Syntax where
import Data.Foldable (Foldable)
import Data.Map (Map)
import qualified Data.Map as Map
import Data.Set (Set)
import qualified Data.Set as Set
import Data.Traversable (Traversable)
import Test.QuickCheck
import Agda.TypeChecking.SizedTypes.Utils
-- | Constant finite offsets added to size variables.  Valid offsets are
--   non-negative (see 'validOffset'), but the type does not enforce this.
newtype Offset = O Int
  deriving (Eq, Ord, Num, Enum)

-- | Show an offset as the bare number it wraps, so that size expressions
--   render as @i+2@ rather than @i+O 2@.  Delegating to 'showsPrec' keeps
--   negative offsets correctly parenthesised inside larger terms.
instance Show Offset where
  showsPrec p (O n) = showsPrec p n
-- | The meet of two offsets is the smaller of the two.
instance MeetSemiLattice Offset where
  meet x y
    | x <= y    = x
    | otherwise = y
-- | Offsets are added by adding the wrapped integers.
instance Plus Offset Offset Offset where
  plus (O m) (O n) = O (m `plus` n)
-- | Random generation of offsets, restricted to non-negative values.
--
--   The wrapped integer is drawn as a 'NonNegative' 'Int'.  Drawing a
--   @NonNegative Offset@ instead would resolve back to this very instance
--   (its 'Arbitrary' instance requires @Arbitrary Offset@) and would never
--   terminate.
instance Arbitrary Offset where
  arbitrary = do
    NonNegative n <- arbitrary
    return (O n)

  -- Shrink the underlying integer, discarding negative candidates.
  shrink (O x) = map O $ filter (>= 0) (shrink x)
-- | Names of rigid size variables.  A rigid variable is shown as its bare
--   name.
newtype Rigid = RigidId { rigidId :: String }
  deriving (Eq, Ord)

instance Show Rigid where
  show (RigidId name) = name
-- | Names of flexible size variables.  A flexible variable is shown as
--   its bare name.
newtype Flex = FlexId { flexId :: String }
  deriving (Eq, Ord)

instance Show Flex where
  show (FlexId name) = name
-- | Size expressions over the default rigid and flexible variable types.
type SizeExpr = SizeExpr' Rigid Flex

-- | A size expression: a constant, infinity, or a rigid or flexible
--   variable plus an offset.
--
--   The derived 'Ord' instance follows constructor order, so do not
--   reorder the constructors.  'offset' is a partial field: it has no
--   value on 'Infty'.
data SizeExpr' rigid flex
  = Const { offset :: Offset }                    -- ^ constant size @n@
  | Rigid { rigid  :: rigid, offset :: Offset }   -- ^ rigid variable plus offset
  | Infty                                         -- ^ infinity
  | Flex  { flex   :: flex,  offset :: Offset }   -- ^ flexible variable plus offset
  deriving (Eq, Ord, Functor, Foldable, Traversable)
-- | The comparison between the two sides of a constraint: strict ('Lt')
--   or non-strict ('Le').  'Lt' is declared first, so 'minBound' is 'Lt'
--   and 'maxBound' is 'Le'.
data Cmp = Lt | Le
  deriving (Eq, Bounded, Enum)
-- | Composition picks the smaller comparison ('Lt' below 'Le'), so the
--   result is strict whenever either argument is.  'Le' is the neutral
--   element.
instance Dioid Cmp where
  compose Lt _ = Lt
  compose Le y = y
  unitCompose  = Le
-- | Orders 'Lt' below 'Le'.  Only '<=' is defined here; the other methods
--   use the class defaults.
instance Ord Cmp where
  Le <= Lt = False
  _  <= _  = True
-- | The meet of two comparisons is the smaller one ('Lt' below 'Le').
instance MeetSemiLattice Cmp where
  meet Lt _ = Lt
  meet Le y = y
-- | The top comparison is the largest one, 'Le'.
instance Top Cmp where
  top = maxBound
-- | Pick either comparison with equal probability.
instance Arbitrary Cmp where
  arbitrary = elements [minBound .. maxBound]
-- | Constraints over the default rigid and flexible variable types.
type Constraint = Constraint' Rigid Flex

-- | A constraint @leftExpr cmp rightExpr@ between two size expressions.
data Constraint' rigid flex = Constraint
  { leftExpr  :: SizeExpr' rigid flex  -- ^ left-hand side
  , cmp       :: Cmp                   -- ^ strict or non-strict comparison
  , rightExpr :: SizeExpr' rigid flex  -- ^ right-hand side
  }
  deriving (Functor, Foldable, Traversable)
-- | Polarity of a flexible variable.  'getPolarity' uses 'Least' when no
--   polarity is recorded.
--   NOTE(review): presumably selects whether the solver aims for the least
--   or the greatest solution of a variable; confirm against the solver.
data Polarity
  = Least
  | Greatest
  deriving (Eq, Ord)
-- | Assigns a 'Polarity' to a single flexible variable.
data PolarityAssignment flex
  = PolarityAssignment Polarity flex
-- | Polarities of flexible variables.  Variables missing from the map are
--   treated as 'Least' by 'getPolarity'.
type Polarities flex = Map flex Polarity

-- | The polarity map with no explicit assignments.
emptyPolarities :: Polarities flex
emptyPolarities = Map.empty

-- | Build a polarity map from assignments.  When a variable is assigned
--   more than once, the last assignment in the list wins.
polaritiesFromAssignments :: Ord flex => [PolarityAssignment flex] -> Polarities flex
polaritiesFromAssignments assignments =
  Map.fromList [ (x, pol) | PolarityAssignment pol x <- assignments ]

-- | Look up the polarity of a variable, defaulting to 'Least'.
getPolarity :: Ord flex => Polarities flex -> flex -> Polarity
getPolarity pols x =
  case Map.lookup x pols of
    Nothing  -> Least
    Just pol -> pol
-- | A (possibly partial) solution that maps flexible variables to size
--   expressions.
type Solution rigid flex = Map flex (SizeExpr' rigid flex)

-- | Structures into which a 'Solution' can be substituted.
class Substitute r f a where
  -- | Replace the flexible variables bound by the solution.
  subst :: Solution r f -> (a -> a)
-- | Substitute into a size expression.  A flexible variable bound by the
--   solution is replaced by its value shifted by the variable's own offset.
--   Unbound flexible variables and all other expressions are returned
--   unchanged.
instance Ord f => Substitute r f (SizeExpr' r f) where
  subst sol e =
    case e of
      -- Shift only when x is actually bound.  Shifting the unchanged
      -- expression @Flex x n@ by n would count the offset twice.
      Flex x n -> maybe e (`plus` n) (Map.lookup x sol)
      _        -> e
-- | Substitute into both sides of a constraint; the comparison is kept.
instance Ord f => Substitute r f (Constraint' r f) where
  subst sol c = c { leftExpr  = subst sol (leftExpr c)
                  , rightExpr = subst sol (rightExpr c)
                  }
-- | Substitute into every element of a list.
instance Substitute r f a => Substitute r f [a] where
  subst sol xs = map (subst sol) xs
-- | Shift a size expression by an offset.  Infinity absorbs any shift;
--   every other constructor carries an 'offset' field, which is increased.
instance Plus (SizeExpr' r f) Offset (SizeExpr' r f) where
  plus Infty _ = Infty
  plus e     m = e { offset = offset e + m }
-- | A constraint transformer.  It returns 'Nothing' when the constraint is
--   inconsistent, or @Just cs@ with replacement constraints @cs@
--   (@Just []@ when the constraint holds trivially).
type CTrans r f = Constraint' r f -> Maybe [Constraint' r f]

-- | Simplify a single constraint by moving offsets to one side and
--   discharging constraints that are trivially true or false.
--
--   Constraints between rigid variables and constants cannot be decided
--   here.  They are normalised and then handed to @test@.  Constraints
--   that mention a flexible variable are normalised so that at most one
--   side carries a non-zero offset, with strict comparisons against
--   constants or rigid variables turned into @≤@ by adjusting the offset
--   by one.
--
--   The rules treat variables as non-negative.  For example, @m ≤ i + n@
--   is discharged whenever @m ≤ n@.
simplify1 :: Eq r => CTrans r f -> CTrans r f
simplify1 test c =
  case c of
    -- Right-hand side is infinity.
    Constraint _           Le  Infty -> Just []
    Constraint Const{}     Lt  Infty -> Just []
    Constraint Infty       Lt  Infty -> Nothing
    Constraint (Rigid i _) Lt  Infty -> test $ Constraint (Rigid i 0) Lt Infty
    Constraint a@Flex{}    Lt  Infty -> Just [c { leftExpr = a { offset = 0 }}]

    -- Right-hand side is a constant.
    Constraint (Const n)   cmp (Const m) ->
      if compareOffset n cmp m then Just [] else Nothing
    Constraint Infty       _   Const{}   -> Nothing
    Constraint (Rigid i n) cmp (Const m) ->
      -- i + n ≤ m  iff  i ≤ m - n;   i + n < m  iff  i ≤ m - n - 1
      if compareOffset n cmp m
        then test (Constraint (Rigid i 0) Le (Const (m - n - ifLe cmp 0 1)))
        else Nothing
    Constraint (Flex x n)  cmp (Const m) ->
      -- x + n ≤ m  iff  x ≤ m - n;   x + n < m  iff  x ≤ m - n - 1
      if compareOffset n cmp m
        then Just [Constraint (Flex x 0) Le (Const (m - n - ifLe cmp 0 1))]
        else Nothing

    -- Right-hand side is a rigid variable.
    Constraint Infty       _   Rigid{}     -> Nothing
    Constraint (Const m)   cmp (Rigid i n) ->
      -- m ≤ i + n holds outright when m ≤ n; otherwise check m - n ≤ i.
      if compareOffset m cmp n
        then Just []
        else test (Constraint (Const $ m - n) cmp (Rigid i 0))
    Constraint (Rigid j m) cmp (Rigid i n) | i == j ->
      -- j + m ≤ j + n  iff  m ≤ n
      if compareOffset m cmp n then Just [] else Nothing
    Constraint Rigid{}     _   Rigid{}     -> test c
    Constraint (Flex x m)  cmp (Rigid i n) ->
      -- x + m ≤ i + n: keep the remaining offset on the side where it is
      -- non-negative.
      if compareOffset m cmp n
        then Just [Constraint (Flex x 0) Le (Rigid i (n - m - ifLe cmp 0 1))]
        else Just [Constraint (Flex x $ m - n + ifLe cmp 0 1) Le (Rigid i 0)]

    -- Right-hand side is a flexible variable.
    Constraint Infty       Le  (Flex x _)  -> Just [Constraint Infty Le (Flex x 0)]
    Constraint Infty       Lt  Flex{}      -> Nothing
    Constraint (Const m)   cmp (Flex x n)  ->
      if compareOffset m cmp n
        then Just []
        else Just [Constraint (Const $ m - n + ifLe cmp 0 1) Le (Flex x 0)]
    Constraint (Rigid i m) cmp (Flex x n)  ->
      if compareOffset m cmp n
        then Just [Constraint (Rigid i 0) cmp (Flex x $ n - m)]
        else Just [Constraint (Rigid i $ m - n) cmp (Flex x 0)]
    Constraint (Flex y m)  cmp (Flex x n)  ->
      if compareOffset m cmp n
        then Just [Constraint (Flex y 0) cmp (Flex x $ n - m)]
        else Just [Constraint (Flex y $ m - n) cmp (Flex x 0)]
-- | Select by comparison: the first value for 'Le', the second for 'Lt'.
ifLe :: Cmp -> a -> a -> a
ifLe c whenLe whenLt =
  case c of
    Le -> whenLe
    Lt -> whenLt
-- | Decide whether @n cmp m@ holds for two offsets.
compareOffset :: Offset -> Cmp -> Offset -> Bool
compareOffset n c m =
  case c of
    Le -> n <= m
    Lt -> n <  m
-- | Render a size expression.  Constants are shown as their offset and
--   infinity as @∞@.  A variable is shown alone when its offset is zero,
--   otherwise as the variable, @+@, and the offset.
instance (Show r, Show f) => Show (SizeExpr' r f) where
  show e =
    case e of
      Const n   -> show n
      Infty     -> "∞"
      Rigid i n -> withOffset (show i) n
      Flex  x n -> withOffset (show x) n
    where
      withOffset v 0 = v
      withOffset v n = v ++ "+" ++ show n
-- | Polarities are rendered as a sign: @-@ for 'Least', @+@ for 'Greatest'.
instance Show Polarity where
  show p =
    case p of
      Least    -> "-"
      Greatest -> "+"
-- | Render an assignment as its polarity sign immediately followed by the
--   shown variable.
instance Show flex => Show (PolarityAssignment flex) where
  show (PolarityAssignment pol x) = concat [show pol, show x]
-- | Comparisons are rendered as their mathematical symbols.
instance Show Cmp where
  show c =
    case c of
      Lt -> "<"
      Le -> "≤"
-- | Render a constraint as left side, comparison and right side, with no
--   separating spaces.
instance (Show r, Show f) => Show (Constraint' r f) where
  show (Constraint lhs c rhs) = concat [show lhs, show c, show rhs]
-- | Checks that all offsets in a structure are non-negative.
class ValidOffset a where
  validOffset :: a -> Bool

-- | An offset is valid when it is at least zero.
instance ValidOffset Offset where
  validOffset n = 0 <= n
-- | Infinity carries no offset and is always valid; any other size
--   expression is valid when its offset is.
instance ValidOffset (SizeExpr' r f) where
  validOffset Infty = True
  validOffset e     = validOffset (offset e)
-- | Clamps negative offsets in a structure to zero.
class TruncateOffset a where
  truncateOffset :: a -> a

-- | Negative offsets become zero; non-negative ones are kept.
instance TruncateOffset Offset where
  truncateOffset = max 0
-- | Clamp the offset of a size expression; infinity is returned as is.
instance TruncateOffset (SizeExpr' r f) where
  truncateOffset Infty = Infty
  truncateOffset e     = e { offset = truncateOffset (offset e) }
-- | Collect the rigid variables occurring in a structure.
class Rigids r a where
  rigids :: a -> Set r

-- | The rigid variables of all list elements, merged.
instance (Ord r, Rigids r a) => Rigids r [a] where
  rigids = Set.unions . map rigids

-- | A size expression mentions a rigid variable only in its 'Rigid' form.
instance Rigids r (SizeExpr' r f) where
  rigids e =
    case e of
      Rigid x _ -> Set.singleton x
      _         -> Set.empty

-- | The rigid variables of both sides of a constraint.
instance Ord r => Rigids r (Constraint' r f) where
  rigids (Constraint lhs _ rhs) = rigids lhs `Set.union` rigids rhs
-- | Collect the flexible variables occurring in a structure.  The
--   functional dependency makes the variable type determined by the
--   structure type.
class Flexs flex a | a -> flex where
  flexs :: a -> Set flex

-- | The flexible variables of all list elements, merged.
instance (Ord flex, Flexs flex a) => Flexs flex [a] where
  flexs = Set.unions . map flexs

-- | A size expression mentions a flexible variable only in its 'Flex' form.
instance (Ord flex) => Flexs flex (SizeExpr' rigid flex) where
  flexs e =
    case e of
      Flex x _ -> Set.singleton x
      _        -> Set.empty

-- | The flexible variables of both sides of a constraint.
instance (Ord flex) => Flexs flex (Constraint' rigid flex) where
  flexs (Constraint lhs _ rhs) = flexs lhs `Set.union` flexs rhs