module Agda.TypeChecking.SizedTypes.Syntax where
import Data.Foldable (Foldable)
import Data.Map (Map)
import qualified Data.Map as Map
import Data.Set (Set)
import qualified Data.Set as Set
import Data.Traversable (Traversable)
import Agda.TypeChecking.SizedTypes.Utils
-- | Integer offset attached to a size expression.  Offsets are not
--   guaranteed to be non-negative; see 'validOffset' / 'truncateOffset'.
newtype Offset = O Int
  deriving (Eq, Ord, Num, Enum)

-- | An offset is displayed as its bare integer.
instance Show Offset where
  show (O k) = show k

-- | The meet of two offsets is the smaller one.
instance MeetSemiLattice Offset where
  meet = min

-- | Offsets are added via the 'Plus' instance of the wrapped 'Int'.
instance Plus Offset Offset Offset where
  plus (O a) (O b) = O $ a `plus` b
-- | Rigid size variables, identified by their name.
newtype Rigid = RigidId { rigidId :: String }
  deriving (Eq, Ord)

-- | A rigid variable is shown as its raw name (without quotes).
instance Show Rigid where
  show (RigidId name) = name
-- | Flexible size variables, identified by their name.
newtype Flex = FlexId { flexId :: String }
  deriving (Eq, Ord)

-- | A flexible variable is shown as its raw name (without quotes).
instance Show Flex where
  show (FlexId name) = name
-- | Size expressions, parameterised by the type of rigid variables and
--   the type of flexible variables.
--
--   NOTE(review): 'offset' is a partial field: it is absent from 'Infty',
--   so callers must handle 'Infty' before selecting it.
--   The constructor order below determines the derived 'Ord' instance.
data SizeExpr' rigid flex
  = Const { offset :: Offset }                  -- ^ A constant size.
  | Rigid { rigid :: rigid, offset :: Offset }  -- ^ Rigid variable plus offset.
  | Infty                                       -- ^ Infinity (shown as @∞@).
  | Flex  { flex :: flex, offset :: Offset }    -- ^ Flexible variable plus offset.
  deriving (Eq, Ord, Functor, Foldable, Traversable)

-- | Size expressions over the concrete 'Rigid' and 'Flex' variable types.
type SizeExpr = SizeExpr' Rigid Flex
-- | Comparison operator of a size constraint: strict ('Lt') or
--   non-strict ('Le').
data Cmp
  = Lt
  | Le
  deriving (Eq, Bounded, Enum)

-- | Composition keeps the stricter comparison; 'Le' is the unit.
instance Dioid Cmp where
  compose     = min
  unitCompose = Le

-- | 'Lt' is ordered strictly below 'Le'.
--   The catch-all equation does not inspect the second argument, so
--   @Lt <= _@ is 'True' without forcing it.
instance Ord Cmp where
  Le <= Lt = False
  _  <= _  = True

-- | The meet of two comparisons is the stricter one.
instance MeetSemiLattice Cmp where
  meet = min

-- | 'Le' is the top element.
instance Top Cmp where
  top = Le
-- | A size constraint of the form @leftExpr cmp rightExpr@.
data Constraint' rigid flex = Constraint
  { leftExpr  :: SizeExpr' rigid flex  -- ^ Left-hand side.
  , cmp       :: Cmp                   -- ^ Strict or non-strict comparison.
  , rightExpr :: SizeExpr' rigid flex  -- ^ Right-hand side.
  }
  deriving (Functor, Foldable, Traversable)

-- | Constraints over the concrete 'Rigid' and 'Flex' variable types.
type Constraint = Constraint' Rigid Flex
-- | Polarity attached to a flexible variable.
--   NOTE(review): presumably selects whether the least or the greatest
--   solution is preferred; confirm against the solver.
data Polarity
  = Least
  | Greatest
  deriving (Eq, Ord)

-- | A polarity paired with the flexible variable it applies to.
data PolarityAssignment flex = PolarityAssignment Polarity flex

-- | Map from flexible variables to their polarity.
type Polarities flex = Map flex Polarity
-- | The polarity map with no assignments.
emptyPolarities :: Polarities flex
emptyPolarities = Map.empty

-- | Build a polarity map from a list of assignments.  If a variable
--   occurs more than once, the last assignment in the list wins
--   (behaviour of 'Map.fromList').
polaritiesFromAssignments :: Ord flex => [PolarityAssignment flex] -> Polarities flex
polaritiesFromAssignments assignments =
  Map.fromList [ (var, pol) | PolarityAssignment pol var <- assignments ]

-- | Polarity of a flexible variable; variables without an explicit
--   assignment default to 'Least'.
getPolarity :: Ord flex => Polarities flex -> flex -> Polarity
getPolarity pols var =
  case Map.lookup var pols of
    Just pol -> pol
    Nothing  -> Least
-- | An assignment of size expressions to (some) flexible variables.
type Solution rigid flex = Map flex (SizeExpr' rigid flex)

-- | Structures into which a 'Solution' can be substituted.
class Substitute r f a where
  -- | Apply the solution to the given value.
  subst :: Solution r f -> a -> a
-- | Apply a solution to a size expression.
--
--   A flexible term @x + n@ whose variable @x@ is solved by @e@ becomes
--   @e + n@.  Unsolved flexible terms, constants, rigid terms and 'Infty'
--   are returned unchanged.
instance Ord f => Substitute r f (SizeExpr' r f) where
  subst sol e =
    case e of
      -- Add the offset only when x is actually solved.  Falling back to
      -- @e `plus` n@ for an unsolved x would double the offset
      -- (x + n would turn into x + 2n).
      Flex x n -> maybe e (`plus` n) (Map.lookup x sol)
      _        -> e
-- | Substitute into both sides of a constraint; the comparison is kept.
instance Ord f => Substitute r f (Constraint' r f) where
  subst sol c =
    c { leftExpr  = subst sol (leftExpr c)
      , rightExpr = subst sol (rightExpr c)
      }
-- | Substitute into every element of a list.
instance Substitute r f a => Substitute r f [a] where
  subst sol = map (subst sol)

-- | Substitute into every value of a map; keys are left untouched.
instance Substitute r f a => Substitute r f (Map k a) where
  subst sol = fmap (subst sol)
-- | Add an offset to a size expression.  'Infty' absorbs any offset;
--   every other constructor has its 'offset' field increased.
instance Plus (SizeExpr' r f) Offset (SizeExpr' r f) where
  plus Infty _ = Infty
  -- Safe record update: 'Infty', the only constructor without an
  -- 'offset' field, is handled by the equation above.
  plus e     m = e { offset = offset e + m }
-- | A constraint transformer: either an error message, or a list of
--   constraints replacing the input constraint.
type CTrans r f =
  Constraint' r f -> Either String [Constraint' r f]
-- | Simplify a single size constraint.
--
--   Returns @Left msg@ when the constraint is recognised as inconsistent.
--   Otherwise it returns @Right cs@, where @cs@ is a list of simpler
--   constraints; the empty list means the constraint is trivially true.
--   Strict comparisons against constants and rigid variables are turned
--   into non-strict ones by subtracting one (@ifLe cmp 0 1@), and offsets
--   are moved to one side of the comparison.
--
--   Constraints that still involve rigid variables after normalisation
--   are delegated to the callback @test@.
simplify1 :: (Show f, Show r, Eq r) => CTrans r f -> CTrans r f
simplify1 test c = do
  let err = Left $ "size constraint " ++ show c ++ " is inconsistent"
  case c of
    -- Right-hand side is ∞.
    Constraint _ Le Infty -> return []
    Constraint Const{} Lt Infty -> return []
    Constraint Infty Lt Infty -> err
    Constraint (Rigid i _) Lt Infty -> test $ Constraint (Rigid i 0) Lt Infty
    Constraint a@Flex{} Lt Infty -> return [c { leftExpr = a { offset = 0 }}]

    -- Right-hand side is a constant.
    Constraint (Const n) cmp (Const m) ->
      if compareOffset n cmp m then return [] else err
    Constraint Infty _ Const{} -> err
    -- i + n `cmp` m  <=>  i <= m - n (- 1 if strict); impossible unless n `cmp` m.
    Constraint (Rigid i n) cmp (Const m) ->
      if compareOffset n cmp m
        then test (Constraint (Rigid i 0) Le (Const (m - n - ifLe cmp 0 1)))
        else err
    Constraint (Flex x n) cmp (Const m) ->
      if compareOffset n cmp m
        then return [Constraint (Flex x 0) Le (Const (m - n - ifLe cmp 0 1))]
        else err

    -- Right-hand side is a rigid variable.
    Constraint Infty _ Rigid{} -> err
    Constraint (Const m) cmp (Rigid i n) ->
      if compareOffset m cmp n
        then return []
        else test (Constraint (Const $ m - n) cmp (Rigid i 0))
    -- Same rigid variable on both sides: only the offsets matter.
    Constraint (Rigid j m) cmp (Rigid i n) | i == j ->
      if compareOffset m cmp n then return [] else err
    Constraint Rigid{} _ Rigid{} -> test c
    -- x + m `cmp` i + n: move the offset difference to the side where it
    -- stays non-negative.
    Constraint (Flex x m) cmp (Rigid i n) ->
      if compareOffset m cmp n
        then return [Constraint (Flex x 0) Le (Rigid i (n - m - ifLe cmp 0 1))]
        else return [Constraint (Flex x $ m - n + ifLe cmp 0 1) Le (Rigid i 0)]

    -- Right-hand side is a flexible variable.
    Constraint Infty Le (Flex x _) -> return [Constraint Infty Le (Flex x 0)]
    Constraint Infty Lt Flex{} -> err
    Constraint (Const m) cmp (Flex x n) ->
      if compareOffset m cmp n
        then return []
        else return [Constraint (Const $ m - n + ifLe cmp 0 1) Le (Flex x 0)]
    Constraint (Rigid i m) cmp (Flex x n) ->
      if compareOffset m cmp n
        then return [Constraint (Rigid i 0) cmp (Flex x $ n - m)]
        else return [Constraint (Rigid i $ m - n) cmp (Flex x 0)]
    Constraint (Flex y m) cmp (Flex x n) ->
      if compareOffset m cmp n
        then return [Constraint (Flex y 0) cmp (Flex x $ n - m)]
        else return [Constraint (Flex y $ m - n) cmp (Flex x 0)]
-- | @ifLe c a b@ picks @a@ when @c@ is 'Le' and @b@ when @c@ is 'Lt'.
ifLe :: Cmp -> a -> a -> a
ifLe c whenLe whenLt =
  case c of
    Le -> whenLe
    Lt -> whenLt
-- | Decide @n `cmp` m@ for two offsets.
compareOffset :: Offset -> Cmp -> Offset -> Bool
compareOffset n c m =
  case c of
    Le -> n <= m
    Lt -> n <  m
-- | Render a size expression.
--
--   Constants are shown as numbers and 'Infty' as @∞@.  Variables are
--   shown as @x@, or as @x+n@ when the offset @n@ is non-zero.
instance (Show r, Show f) => Show (SizeExpr' r f) where
  show e =
    case e of
      Const n   -> show n
      Infty     -> "∞"
      Rigid i n -> withOffset (show i) n
      Flex  x n -> withOffset (show x) n
    where
      -- Append "+n" unless the offset is zero.
      withOffset s 0 = s
      withOffset s n = s ++ "+" ++ show n
-- | 'Least' is shown as "-", 'Greatest' as "+".
instance Show Polarity where
  show pol =
    case pol of
      Least    -> "-"
      Greatest -> "+"

-- | An assignment is shown as its polarity sign directly followed by the
--   variable.
instance Show flex => Show (PolarityAssignment flex) where
  show (PolarityAssignment pol var) = concat [show pol, show var]
-- | Comparisons are shown with their mathematical symbols.
instance Show Cmp where
  show c =
    case c of
      Le -> "≤"
      Lt -> "<"

-- | A constraint is shown as @lhs cmp rhs@, separated by single spaces.
instance (Show r, Show f) => Show (Constraint' r f) where
  show (Constraint lhs c rhs) = unwords [show lhs, show c, show rhs]
-- | Check whether the offsets of a value are non-negative.
class ValidOffset a where
  validOffset :: a -> Bool

-- | An offset is valid iff it is non-negative.
instance ValidOffset Offset where
  validOffset n = n >= 0

-- | 'Infty' carries no offset and is always valid; every other size
--   expression is valid iff its offset is.
instance ValidOffset (SizeExpr' r f) where
  validOffset Infty = True
  validOffset e     = validOffset (offset e)
-- | Clamp negative offsets to zero.
class TruncateOffset a where
  truncateOffset :: a -> a

-- | Negative offsets become 0; non-negative ones are unchanged.
instance TruncateOffset Offset where
  truncateOffset n = if n >= 0 then n else 0

-- | 'Infty' is unchanged; other expressions have their offset clamped.
instance TruncateOffset (SizeExpr' r f) where
  truncateOffset Infty = Infty
  -- Record update is safe: 'Infty' is handled above.
  truncateOffset e     = e { offset = truncateOffset (offset e) }
-- | Collect the rigid variables occurring in a value.
class Rigids r a where
  rigids :: a -> Set r

-- | Union of the rigid variables of all list elements.
instance (Ord r, Rigids r a) => Rigids r [a] where
  rigids = Set.unions . map rigids

-- | A 'Rigid' term contributes its variable; all other terms contribute none.
instance Rigids r (SizeExpr' r f) where
  rigids e =
    case e of
      Rigid x _ -> Set.singleton x
      _         -> Set.empty

-- | Rigid variables of both sides of a constraint.
instance Ord r => Rigids r (Constraint' r f) where
  rigids c = rigids (leftExpr c) `Set.union` rigids (rightExpr c)
-- | Collect the flexible variables occurring in a value.  The functional
--   dependency lets the variable type be determined by the value's type.
class Flexs flex a | a -> flex where
  flexs :: a -> Set flex

-- | Union of the flexible variables of all list elements.
instance (Ord flex, Flexs flex a) => Flexs flex [a] where
  flexs = Set.unions . map flexs

-- | A 'Flex' term contributes its variable; all other terms contribute none.
instance Flexs flex (SizeExpr' rigid flex) where
  flexs e =
    case e of
      Flex x _ -> Set.singleton x
      _        -> Set.empty

-- | Flexible variables of both sides of a constraint.
instance (Ord flex) => Flexs flex (Constraint' rigid flex) where
  flexs c = flexs (leftExpr c) `Set.union` flexs (rightExpr c)