{-# LANGUAGE DataKinds              #-}
{-# LANGUAGE FlexibleContexts       #-}
{-# LANGUAGE FlexibleInstances      #-}
{-# LANGUAGE FunctionalDependencies #-}
{-# LANGUAGE KindSignatures         #-}
{-# LANGUAGE MultiParamTypeClasses  #-}
{-# LANGUAGE ScopedTypeVariables    #-}
{-# LANGUAGE TypeApplications       #-}
{-# LANGUAGE UndecidableInstances   #-}
{-# LANGUAGE ViewPatterns           #-}

-- | Infrastructure for supporting matching on records
--
-- We are careful not to reintroduce quadratic code size here.
module Data.Record.QQ.Runtime.MatchHasField (
    MatchHasField -- opaque
  , matchHasField
  , fieldNamed
  , viewAtType
  ) where

import Data.Kind
import GHC.Records.Compat
import GHC.TypeLits

{-------------------------------------------------------------------------------
  Infrastructure
-------------------------------------------------------------------------------}

-- | Pattern match on 'HasField'
--
-- An instance @MatchHasField r p@ says how to turn a record of type @r@ into
-- a value of the pattern shape @p@. The instances in this module cover
-- single field selections (@GetField@) and pairs of shapes. The functional
-- dependency lets the pattern shape alone determine the record type.
--
-- This is intended to be used together with 'matchHasField' inside a view
-- pattern. Example usage:
--
-- > data Foo a
-- >
-- > instance HasField "fooX" (Foo a) Int where ..
-- > instance HasField "fooY" (Foo a) [a] where ..
-- >
-- > _example :: Foo Char -> (Int, [Char])
-- > _example (matchHasField -> ( fieldNamed @"fooX" -> x
-- >                            , fieldNamed @"fooY" -> y
-- >                            ) ) = (x, y)
class MatchHasField r p | p -> r where
  -- | Project a record into the requested pattern shape.
  matchHasField :: r -> p

-- | Extract a field value captured by 'matchHasField'.
--
-- To be used in conjunction with 'MatchHasField'. The type application in
-- @fieldNamed \@"fieldName"@ fixes the field label of the @GetField@ wrapper.
-- That label determines which 'HasField' instance 'matchHasField' consults,
-- and this function then unwraps the result.
--
-- See 'MatchHasField' for details.
fieldNamed :: GetField x r a -> a
fieldNamed (GetField a) = a

-- | The value of the field labelled @fld@ in a record of type @record@.
--
-- Both the label and the record type appear only at the type level. They
-- are what lets the single-field 'MatchHasField' instance decide which
-- 'HasField' instance to use. Values are built by 'matchHasField' and
-- unwrapped by 'fieldNamed'.
data GetField (fld :: Symbol) (record :: Type) (value :: Type) =
    GetField value

-- | Base case: match a single field by reading field @x@ of the record
-- through its 'HasField' instance and wrapping the result in @GetField@.
instance HasField x r a => MatchHasField r (GetField x r a) where
  -- @x@ is bound by the instance head (ScopedTypeVariables). The type
  -- application selects which field 'getField' reads.
  matchHasField = GetField . getField @x

-- | Pairs: match the same record against both components independently.
--
-- This is the only tuple instance, so a pattern over more than two fields
-- has to be written as nested pairs.
instance (MatchHasField a b, MatchHasField a c) => MatchHasField a (b, c) where
  matchHasField r = (matchHasField r, matchHasField r)

-- | Can be used alongside 'matchHasField' to fix the type of the argument
--
-- @viewAtType proxy@ is the identity function at the type of @proxy@. The
-- first argument is never evaluated, so any value of the right type works,
-- e.g. a constructor such as @MkFoo@.
--
-- This avoids inferring types in terms of @HasField ..@; see @_example2@
-- below.
viewAtType :: a -> a -> a
viewAtType _proxy x = x

{-------------------------------------------------------------------------------
  Example
-------------------------------------------------------------------------------}

-- | Example record type used by the examples below.
--
-- It holds no data. The 'HasField' instances exist only so that the
-- examples type check, and they fail if they are ever evaluated.
data Foo a = MkFoo

instance HasField "fooX" (Foo a) Int where
  hasField _ = error "Foo: example-only HasField instance for \"fooX\""

instance HasField "fooY" (Foo a) [a] where
  hasField _ = error "Foo: example-only HasField instance for \"fooY\""

-- | Example: match two fields of any record type that provides them.
--
-- Nothing fixes the type of the argument here, so the type is expressed
-- in terms of 'HasField' constraints.
_example1 :: (HasField "fooX" a b, HasField "fooY" a c) => a -> (b, c)
_example1 (matchHasField -> ( fieldNamed @"fooX" -> x
                            , fieldNamed @"fooY" -> y
                            ) ) = (x, y)

-- | Example: as @_example1@, but 'viewAtType' pins the argument to @Foo a@.
-- The field types are then resolved through the concrete 'HasField'
-- instances for 'Foo' rather than left as constraints.
_example2 :: Foo a -> (Int, [a]) -- This is the inferred signature
_example2 (viewAtType MkFoo -> matchHasField -> ( fieldNamed @"fooX" -> x
                                                , fieldNamed @"fooY" -> y
                                                ) ) = (x, y)