{-# LANGUAGE FunctionalDependencies #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE CPP #-}

module Strongweak.Strengthen where

import GHC.TypeNats ( Natural, KnownNat, natVal )
import Data.Word
import Data.Int
import Refined ( Refined, refine, Predicate )
import Data.Vector.Sized qualified as Vector
import Data.Vector.Sized ( Vector )
import Type.Reflection ( Typeable, typeRep )
import Data.Proxy ( Proxy(..) )

import Prettyprinter
import Prettyprinter.Render.String

import Data.Either.Validation
import Data.List.NonEmpty ( NonEmpty( (:|) ) )
import Data.Foldable qualified as Foldable

{- | A weak type @w@ can be "strengthened" into a strong type @s@ by asserting
some properties of the weak value.

For example, you may strengthen some 'Natural' @n@ into a 'Word8' by asserting
@0 <= n <= 255@.

Note that we restrict strengthened types to having only one corresponding weak
representation using a functional dependency (@s -> w@).
-}
class Strengthen w s | s -> w where
    -- | Attempt to strengthen a weak value. Failures are reported as a
    --   non-empty list of 'StrengthenError's.
    strengthen :: w -> Validation (NonEmpty StrengthenError) s

-- | 'strengthen', but with the strong type @s@ as the first type variable so
--   that visible type application selects the target type, e.g.
--   @strengthen' \@Word8 n@.
strengthen' :: forall s w. Strengthen w s => w -> Validation (NonEmpty StrengthenError) s
strengthen' w = strengthen w

-- | Errors produced while strengthening. Don't use these constructors
--   directly; use the existing helper functions instead.
--
-- Field indices count from 0 within their constructor. Field names are only
-- present for record fields.
data StrengthenError
  = -- | Strengthening a single weak value failed.
    StrengthenErrorBase
        String -- ^ weak type
        String -- ^ strong type
        String -- ^ weak value (as rendered by 'show')
        String -- ^ message
  | -- | Strengthening one field of a datatype failed; wraps that field's errors.
    StrengthenErrorField
        String                      -- ^ weak datatype name
        String                      -- ^ strong datatype name
        String                      -- ^ weak constructor name
        String                      -- ^ strong constructor name
        Natural                     -- ^ weak field index
        (Maybe String)              -- ^ weak field name (if present)
        Natural                     -- ^ strong field index
        (Maybe String)              -- ^ strong field name (if present)
        (NonEmpty StrengthenError)  -- ^ nested errors for this field
    deriving stock (Eq)

-- | Shows an error by laying out its 'Pretty' rendering with
--   'defaultLayoutOptions'. The precedence argument is ignored.
instance Show StrengthenError where
    showsPrec _ err = renderShowS (layoutPretty defaultLayoutOptions (pretty err))

-- TODO shorten value if over e.g. 50 chars. e.g. @[0,1,2,...,255] -> FAIL@
--
-- A base error renders as three lines:
--   weak type -> strong type
--   weak value -> FAIL
--   message
-- A field error renders the weak @Datatype.Constructor.field@ path (the field
-- name, or its index when it has no name), followed on the next lines by the
-- nested errors as a bulleted list. Strong-side names and indices are
-- currently not rendered.
instance Pretty StrengthenError where
    pretty (StrengthenErrorBase weakTy strongTy weakVal msg) =
        vsep
          [ pretty weakTy  <+> "->" <+> pretty strongTy
          , pretty weakVal <+> "->" <+> "FAIL"
          , pretty msg
          ]
    pretty (StrengthenErrorField dtWeak _ conWeak _ idxWeak mNameWeak _ _ errs) =
           pretty dtWeak <> "." <> pretty conWeak <> "." <> pretty fieldLabel
        <> line
        <> strengthenErrorPretty errs
      where
        -- Prefer the record field name; fall back to the positional index.
        fieldLabel = case mNameWeak of
          Just name -> name
          Nothing   -> show idxWeak

-- | Render a non-empty collection of errors as a vertical list, one
--   @-@-prefixed bullet per error.
--
-- Mutually recursive with the 'Pretty' instance for 'StrengthenError' (field
-- errors render their nested errors through this function). Safe, but a bit
-- confusing - clean up.
strengthenErrorPretty :: NonEmpty StrengthenError -> Doc a
strengthenErrorPretty errs = vsep [ bullet err | err <- Foldable.toList errs ]
  where
    -- 'indent 0' keeps the continuation lines of a multi-line nested error
    -- aligned with the column just after the "- " marker.
    bullet err = "-" <+> indent 0 (pretty err)

-- | Fail strengthening with a single 'StrengthenErrorBase' error.
--
-- The weak and strong type names come from 'typeRep' and the weak value is
-- rendered with 'show'.
strengthenErrorBase
    :: forall s w. (Typeable w, Show w, Typeable s)
    => w       -- ^ the weak value that could not be strengthened
    -> String  -- ^ explanation of the failure
    -> Validation (NonEmpty StrengthenError) s
strengthenErrorBase w msg =
    let weakTy   = show (typeRep @w)
        strongTy = show (typeRep @s)
    in  Failure (StrengthenErrorBase weakTy strongTy (show w) msg :| [])

-- | Strengthen each element of a list via 'traverse'. The element failures are
--   combined by 'Validation''s 'Applicative' instance.
instance Strengthen w s => Strengthen [w] [s] where
    strengthen ws = traverse strengthen ws

-- | Obtain a sized vector by asserting the size of a plain list.
--
-- Fails if the list's length differs from the type-level length @n@. The
-- error message reports both the expected and the actual length.
instance (KnownNat n, Typeable a, Show a) => Strengthen [a] (Vector n a) where
    strengthen w =
        case Vector.fromList w of
          Just s  -> Success s
          Nothing -> strengthenErrorBase w $
                 "incorrect length: expected "
              <> show (natVal (Proxy @n))
              <> ", got "
              <> show (length w)

-- | Obtain a refined type by applying its associated refinement.
--
-- When the predicate rejects the value, the 'show'n refinement exception is
-- used as the error message.
#ifdef REFINED_POLYKIND
instance (Predicate (p :: k) a, Typeable k, Typeable a, Show a) => Strengthen a (Refined p a) where
#else
instance (Predicate p a, Typeable p, Typeable a, Show a) => Strengthen a (Refined p a) where
#endif
    strengthen a = either (strengthenErrorBase a . show) Success (refine a)

-- Strengthen 'Natural's into Haskell's bounded unsigned numeric types. Each
-- instance defers to 'strengthenBounded', which checks the value against the
-- target type's 'minBound' and 'maxBound'.
instance Strengthen Natural Word8 where
    strengthen = strengthenBounded @Word8

instance Strengthen Natural Word16 where
    strengthen = strengthenBounded @Word16

instance Strengthen Natural Word32 where
    strengthen = strengthenBounded @Word32

instance Strengthen Natural Word64 where
    strengthen = strengthenBounded @Word64

-- Strengthen 'Integer's into Haskell's bounded signed numeric types. Each
-- instance defers to 'strengthenBounded', which checks the value against the
-- target type's 'minBound' and 'maxBound'.
instance Strengthen Integer Int8 where
    strengthen = strengthenBounded @Int8

instance Strengthen Integer Int16 where
    strengthen = strengthenBounded @Int16

instance Strengthen Integer Int32 where
    strengthen = strengthenBounded @Int32

instance Strengthen Integer Int64 where
    strengthen = strengthenBounded @Int64

-- | Strengthen an integral value into the bounded integral type @b@,
--   succeeding only when it lies within @b@'s 'minBound' and 'maxBound'
--   (inclusive).
--
-- The range check is done on 'Integer's, so it is correct for any pair of
-- 'Integral' types. Converting @b@'s bounds into the weak type @n@ instead can
-- wrap around (e.g. @'maxBound' :: 'Word64'@ as an 'Int8' is @-1@) or throw
-- (e.g. @'minBound' :: 'Int8'@ cannot be represented as a 'Natural').
strengthenBounded
    :: forall b n
    .  (Integral b, Bounded b, Show b, Typeable b, Integral n, Show n, Typeable n)
    => n -> Validation (NonEmpty StrengthenError) b
strengthenBounded n
  | minB <= i && i <= maxB = Success (fromIntegral n)  -- in range: lossless
  | otherwise = strengthenErrorBase n $
        "not well bounded, require: " <> show minB <> " <= n <= " <> show maxB
  where
    i    = toInteger n
    minB = toInteger (minBound @b)
    maxB = toInteger (maxBound @b)