{-# LANGUAGE InstanceSigs #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE ViewPatterns #-}
module Graphics.Ray.Geometry 
  ( -- * Geometry
    Geometry(Geometry), boundingBox, pureGeometry, transform, moving
    -- * Surfaces and Volumes
  , sphere, planeShape, parallelogram, cuboid, triangle, triangleMesh, constantMedium
    -- * Polygonal Meshes
  , Mesh(Mesh), transformVertices, parseObj, readObj
    -- * Groups
  , group, bvhNode, bvhTree
    -- * Transformations
  , translate, rotateX, rotateY, rotateZ, scale
  ) where

import Graphics.Ray.Core

import Linear (V2(V2), V3(V3), dot, quadrance, (*^), (^/), cross, norm, M44, inv44, (!*), V4(V4))
import qualified Linear.V4 as V4
import Data.Massiv.Array (U, (!))
import qualified Data.Massiv.Array as A
import System.Random (StdGen, random)
import Control.Monad.State (State, state)
import Control.Monad (guard, foldM)
import Control.Applicative ((<|>))
import Data.List (sortOn)
import Data.Bifunctor (first, second)
import Data.Functor.Identity (Identity(Identity), runIdentity)
import Data.Functor ((<&>))
import Text.Read (readMaybe)
import Data.Char (isDigit)
import GHC.Base (inline)

-- | A @'Geometry' m a@ has a bounding box (used in the implementation of bounding volume hierarchies),
-- as well as a function that takes a time, a ray, and an open interval, and in the @m@ monad, produces either @Nothing@
-- (if the ray does not intersect the shape within that interval) or a tuple consisting of a 'HitRecord' and a value of type @a@.
-- The time parameter is between 0 and 1 and allows for motion blur effects.
-- Typically, @m@ is either 'Identity' or @'State' 'StdGen'@, and @a@ is either @()@ or 'Geometry.Material.Material'. 
-- Use the '(<$)' operator to add a material to a geometry.
--
-- The hit function's arguments are, in order: the time, the ray, and the interval
-- @(tmin, tmax)@ of ray parameters to search. Combinators such as 'group' and 'bvhNode'
-- find the closest hit by shrinking @tmax@ to the 'hr_t' of the best hit found so far.
-- The monad @m@ lets a hit test consume randomness (see 'constantMedium').
data Geometry m a = Geometry Box (Double -> Ray -> Interval -> m (Maybe (HitRecord, a)))

instance Functor m => Functor (Geometry m) where
  {-# SPECIALISE fmap :: (a -> b) -> Geometry Identity a -> Geometry Identity b #-}
  fmap :: (a -> b) -> Geometry m a -> Geometry m b
  -- Relabel the payload of every hit with @f@; the bounding box and the
  -- 'HitRecord' itself are left untouched.
  fmap f (Geometry bbox hitOld) = Geometry bbox hitNew
    where
      hitNew time ray ival = fmap (second f) <$> hitOld time ray ival

-- | Promote a pure geometry to a monadic one.
-- The intersection test is unchanged; its result is simply wrapped with 'pure'.
pureGeometry :: Applicative m => Geometry Identity a -> Geometry m a
pureGeometry (Geometry bbox hitPure) = Geometry bbox hitLifted
  where
    hitLifted time ray ival = pure . runIdentity $ hitPure time ray ival

-- | Get a geometry's bounding box (the first field of the 'Geometry' constructor).
boundingBox :: Geometry m a -> Box
boundingBox geometry = case geometry of
  Geometry bbox _ -> bbox

-- | Construct a sphere with the given center and radius.
--
-- The nearer intersection is reported when it lies in the interval, otherwise the
-- farther one. The reported normal never points along the ray, and 'hr_frontSide'
-- is True when the ray meets the outward-facing side.
-- NOTE(review): the quadratic is solved with its leading coefficient @dot dir dir@
-- taken to be 1, so ray directions are presumably expected to be unit length.
sphere :: Point3 -> Double -> Geometry Identity ()
sphere center radius = Geometry bbox hitSphere
  where
    halfDiagonal = V3 radius radius radius
    bbox = fromCorners (center - halfDiagonal) (center + halfDiagonal)

    hitSphere _ (Ray orig dir) bounds = Identity $ do
      let toCenter = center - orig
          halfB = dot dir toCenter
          c = quadrance toCenter - radius * radius
          discriminant = halfB * halfB - c
      guard (discriminant >= 0)
      let root = sqrt discriminant
          -- keep a candidate parameter only if it lies inside the search interval
          accept r = r <$ guard (inInterval bounds r)
      t <- accept (halfB - root) <|> accept (halfB + root)
      let point = orig + t *^ dir
          outward = (point - center) ^/ radius
          outside = dot dir outward <= 0
          hit = HitRecord
            { hr_t = t
            , hr_point = point
            , hr_normal = if outside then outward else negate outward
            , hr_frontSide = outside
            , hr_uv = sphereUV outward -- lazy: only computed if something demands it
            }
      pure (hit, ())

-- [private]
-- Map an outward normal of a sphere (presumably unit length) to (u, v) texture coordinates.
-- With default camera settings (-z direction is forward, +y direction is up),
-- texture images will be wrapped around the sphere starting and ending on the
-- far side of the sphere. v is 0 at the bottom pole (y = -1) and 1 at the top (y = 1).
sphereUV :: Vec3 -> V2 Double
sphereUV (V3 x y z) =
  let longitude = atan2 x z     -- angle around the y axis, in [-pi, pi]
      polar = acos (negate y)   -- 0 at the bottom pole, pi at the top
  in V2 (longitude / (2 * pi) + 0.5) (polar / pi)

-- | Construct a subset of a plane. (See 'parallelogram' and 'triangle' for two instances of this.)
-- Which side is the \"front side\" is determined by the right hand rule.
--
-- The point @q + a*u + b*v@ of the plane is part of the shape iff @test a b@ holds,
-- in which case its texture coordinates are @getUV a b@. Rays (nearly) parallel to
-- the plane never hit it.
{-# INLINABLE planeShape #-}
planeShape 
  :: Point3 -- ^ (0, 0) point on the plane
  -> Vec3 -- ^ First basis vector
  -> Vec3 -- ^ Second basis vector
  -> (Double -> Double -> Bool) -- ^ Whether a point on the plane is in the shape
  -> (Double -> Double -> V2 Double) -- ^ Texture coordinates
  -> Box -- ^ Bounding box (this is padded by a small amount to ensure all dimensions are positive)
  -> Geometry Identity ()
planeShape q u v test getUV bbox = Geometry (padBox 0.0001 bbox) hitShape
  where
    uxv = cross u v
    len = norm uxv
    unitNormal = uxv ^/ len
    -- (u x v) / |u x v|^2: dotting it with triple products recovers the plane
    -- coordinates (a, b) of a point
    coordVec = unitNormal ^/ len

    hitShape _ (Ray orig dir) bounds = Identity $ do
      let facing = dot unitNormal dir
      guard (abs facing > 1e-8)
      let t = dot unitNormal (q - orig) / facing
      guard (inInterval bounds t)
      let p = orig + t *^ dir
          rel = p - q
          a = dot coordVec (cross rel v)
          b = dot coordVec (cross u rel)
      guard (test a b)
      let front = facing < 0
          hit = HitRecord
            { hr_t = t
            , hr_point = p
            , hr_normal = if front then unitNormal else negate unitNormal
            , hr_frontSide = front
            , hr_uv = getUV a b
            }
      pure (hit, ())

-- | Construct a parallelogram from a corner point and two edge vectors.
-- It covers @q + a*u + b*v@ for @0 <= a, b <= 1@, and @(a, b)@ are its texture coordinates.
parallelogram :: Point3 -> Vec3 -> Vec3 -> Geometry Identity ()
parallelogram q u v = inline planeShape q u v inUnitSquare V2 (boxHull corners)
  where
    corners = [q, q + u, q + v, q + u + v]
    inUnitSquare a b = unitRange a && unitRange b
    unitRange x = 0 <= x && x <= 1

-- | Construct an axis-aligned rectangular cuboid (implemented as a 'group' of parallelograms).
-- Each face's edge vectors are ordered so that, by the right hand rule, its front side faces outward.
cuboid :: Box -> Geometry Identity ()
cuboid (V3 (x0, x1) (y0, y1) (z0, z1)) = group faces
  where
    ex = V3 (x1 - x0) 0 0
    ey = V3 0 (y1 - y0) 0
    ez = V3 0 0 (z1 - z0)
    faces =
      [ parallelogram (V3 x0 y0 z1) ex ey          -- front  (z = zmax)
      , parallelogram (V3 x1 y0 z0) (negate ex) ey -- back   (z = zmin)
      , parallelogram (V3 x0 y0 z0) ez ey          -- left   (x = xmin)
      , parallelogram (V3 x1 y0 z1) (negate ez) ey -- right  (x = xmax)
      , parallelogram (V3 x0 y1 z1) ex (negate ez) -- top    (y = ymax)
      , parallelogram (V3 x0 y0 z0) ex ez          -- bottom (y = ymin)
      ]

-- | Construct a triangle from three corner points and their texture coordinates.
-- Texture coordinates are interpolated barycentrically between the three corners.
triangle :: (Point3, V2 Double) -> (Point3, V2 Double) -> (Point3, V2 Double) -> Geometry Identity ()
triangle (p0, uv0) (p1, uv1) (p2, uv2) =
  inline planeShape p0 edge1 edge2 inside blend (boxHull [p0, p1, p2])
  where
    edge1 = p1 - p0
    edge2 = p2 - p0
    -- (a, b) are the weights of p1 and p2; the weight of p0 is 1 - a - b
    inside a b = a >= 0 && b >= 0 && a + b <= 1
    blend a b = (1 - a - b) *^ uv0 + a *^ uv1 + b *^ uv2

-- | A collection of triangles.
--
-- All indices are 0-based ('parseObj' converts the file's indices accordingly).
-- 'triangleMesh' looks them up in the arrays without further validation, so they are
-- expected to be in range. A corner without a texture-coordinate index gets a default
-- in 'triangleMesh'.
data Mesh = Mesh 
  (A.Vector U Point3) -- ^ Array of vertex locations.
  (A.Vector U (V2 Double)) -- ^ Array of texture coordinates.
  [V3 (Int, Maybe Int)] -- ^ List of triangles. Each triangle has three vertices, whose locations and (optional) texture coordinates
                        -- are defined by indexing into the two arrays.
  deriving (Show)

-- | Apply an affine transformation (represented as a 4 by 4 matrix whose bottom row is 0 0 0 1) to the vertices of a mesh.
-- Texture coordinates and triangles are left unchanged.
transformVertices :: M44 Double -> Mesh -> Mesh
transformVertices m (Mesh vs vts fs) = Mesh (A.compute (A.map movePoint vs)) vts fs
  where
    movePoint p = dropLast (m !* V4.point p)

-- TODO: use IO exceptions instead of Either?
-- | Parse the .obj file at the given location.
-- Parse errors are prefixed with the file path.
readObj :: FilePath -> IO (Either String Mesh) 
readObj path = do
  contents <- readFile path
  pure (first withPath (parseObj contents))
  where
    withPath err = path ++ ", " ++ err

-- | Parse a Wavefront .obj file.
--
-- * Comments beginning with @#@ are ignored, as are all lines that do not
--   begin with @v @, @vt @, or @f @. 
-- * Faces with more than three vertices are allowed, but are triangulated
--   (as a fan around the face's first vertex) before adding them to the
--   'Mesh' object.
-- * Indices can be either positive or negative, and are converted into
--   0-based indices. A positive index counts from the start of the file: if
--   there are 20 @v@ statements in the file, a positive vertex index must be
--   in the range [1, 20]. A negative index counts backwards from the face, as
--   in the OBJ format: -1 means the last @v@ statement /before/ the face, so
--   if 12 @v@ statements precede the face, a negative vertex index must be in
--   the range [-12, -1]. (When every @v@ statement comes before every face,
--   -1 is simply the last vertex in the file.) Texture coordinate indices
--   work the same way with @vt@ statements.
parseObj :: String -> Either String Mesh
parseObj file = do
  let (vLines, vtLines, fLines) = partitionLines (removeComments file)
  vs <- mapM parseV vLines
  vts <- mapM parseVT vtLines
  let fLineNums = map fst fLines
      -- for each 'f' statement, the number of 'v' / 'vt' statements preceding it
      vsBefore = countBefore (map fst vLines) fLineNums
      vtsBefore = countBefore (map fst vtLines) fLineNums
  fs <- concat <$> sequence (zipWith3 (parseF (length vs) (length vts)) vsBefore vtsBefore fLines)
  Right (Mesh (A.fromList A.Seq vs) (A.fromList A.Seq vts) fs)
  
  where
    removeComments :: String -> [String]
    removeComments = map (takeWhile (/= '#')) . lines

    -- split numbered lines by statement kind, preserving file order
    partitionLines :: [String] -> ([(Int, String)], [(Int, String)], [(Int, String)])
    partitionLines = foldr addLine ([], [], []) . zip [1..]

    addLine (k, line) (vs, vts, fs) =
      case line of
        'v':' ':rest -> ((k, rest) : vs, vts, fs)
        'v':'t':' ':rest -> (vs, (k, rest) : vts, fs)
        'f':' ':rest -> (vs, vts, (k, rest) : fs)
        _ -> (vs, vts, fs)

    withLine :: Int -> String -> String
    withLine k err = "line " ++ show k ++ ": " ++ err

    -- @countBefore defs queries@: for each line number in @queries@, how many
    -- line numbers in @defs@ are smaller. Both lists are ascending because
    -- 'partitionLines' preserves file order.
    countBefore :: [Int] -> [Int] -> [Int]
    countBefore = go 0
      where
        go _ _ [] = []
        go n defs (q:qs) =
          let (earlier, later) = span (< q) defs
              n' = n + length earlier
          in n' `seq` (n' : go n' later qs)
    
    -- a 'v' statement must begin with three decimal numbers
    parseV (k, line) = 
      case words line of
        (readMaybe -> Just x) : (readMaybe -> Just y) : (readMaybe -> Just z) : _ -> Right (V3 x y z)
        _ -> Left (withLine k "invalid 'v' statement")
    
    -- a 'vt' statement must begin with two decimal numbers (or consist of a single decimal number, in which case v defaults to 0)
    parseVT (k, line) =
      case words line of
        [readMaybe -> Just u] -> Right (V2 u 0)
        (readMaybe -> Just u) : (readMaybe -> Just v) : _ -> Right (V2 u v)
        _ -> Left (withLine k "invalid 'vt' statement")

    -- an 'f' statement needs at least three vertices; @numVs@/@numVTs@ are the
    -- file totals, @vsBefore@/@vtsBefore@ the counts preceding this statement
    parseF :: Int -> Int -> Int -> Int -> (Int, String) -> Either String [V3 (Int, Maybe Int)]
    parseF numVs numVTs vsBefore vtsBefore (k, line) = 
      case mapM (getIndices (numVs, vsBefore) (numVTs, vtsBefore)) (words line) of
          Left err -> Left (withLine k err)
          Right (i:is) | length is >= 2 -> Right (map (uncurry (V3 i)) (pairs is))
          Right _ -> Left (withLine k "invalid 'f' statement (fewer than 3 vertices)")

    pairs :: [a] -> [(a, a)]
    pairs = \case
      [] -> []
      [_] -> []
      x:xs@(y:_) -> (x, y) : pairs xs

    -- Convert a 1-based index (checked against the file total) or a negative,
    -- relative index (checked against the count preceding the face) to 0-based.
    processIx :: (Int, Int) -> Int -> Either String Int
    processIx (total, before) i 
      | 1 <= i && i <= total = Right (i - 1)
      | -before <= i && i <= -1 = Right (i + before)
      | otherwise = Left ("index out of bounds: " ++ show i)

    -- parse one face vertex of the form @v@, @v/vt@, @v/vt/vn@ or @v//vn@
    -- (normal indices are ignored)
    getIndices :: (Int, Int) -> (Int, Int) -> String -> Either String (Int, Maybe Int)
    getIndices vCounts vtCounts str = do
      (i, rest) <- extractInt str
      i' <- processIx vCounts i
      case rest of
        "" -> Right (i', Nothing)
        '/':'/':_ -> Right (i', Nothing)
        '/':str' -> do
          (j, _) <- extractInt str'
          j' <- processIx vtCounts j
          Right (i', Just j')
        c:_ -> Left ("unexpected character '" ++ c : "'")

    extractInt :: String -> Either String (Int, String)
    extractInt = \case
      '-':str -> first negate <$> extractNat str
      str -> extractNat str
    
    extractNat :: String -> Either String (Int, String)
    extractNat str = 
      let (ds, rest) = span isDigit str in
      case readMaybe ds of
        Nothing -> Left "expected number"
        Just i -> Right (i, rest)

-- | Realize a 'Mesh' as a 'Geometry' (implemented as a 'bvhTree' of triangles).
--
-- A corner without a texture-coordinate index gets a default: (0, 0), (1, 0) and
-- (0, 1) for the first, second and third corner respectively. Because 'bvhTree'
-- rejects an empty list, a mesh must contain at least one triangle.
triangleMesh :: Mesh -> Geometry Identity ()
triangleMesh (Mesh verts uvs tris) = bvhTree (map toTriangle tris)
  where
    corner i j fallback = (verts ! i, maybe fallback (uvs !) j)
    toTriangle (V3 (i0, j0) (i1, j1) (i2, j2)) =
      triangle (corner i0 j0 (V2 0 0)) (corner i1 j1 (V2 1 0)) (corner i2 j2 (V2 0 1))

-- | Construct a constant-density medium; useful for subsurface scattering and fog effects.
-- Typical materials are 'Graphics.Material.isotropic', 'Graphics.Material.anisotropic', and 'Graphics.Material.pitchBlack'.
--
-- For each ray that passes through the medium inside the search interval, one random
-- number is drawn. It selects an exponentially distributed scattering distance
-- (@-log rand / density@, measured in units of the ray parameter). A hit is reported
-- only if that distance falls within the span the ray spends inside the surface.
constantMedium
  :: Double -- ^ Density 
  -> Geometry Identity () -- ^ Surface (assumed to be a closed surface with the \"front side\" facing outwards)
  -> Geometry (State StdGen) ()
constantMedium density (Geometry bbox hitObj) = Geometry bbox hitMedium
  where
    negInvDensity = -(1 / density)

    -- The part of the search interval during which the ray is inside the surface, if any.
    insideSpan time ray (tmin, tmax) = do
      let nextHit from = do
            (hit, ()) <- runIdentity (hitObj time ray (from, infinity))
            Just hit
      entry <- nextHit tmin
      if hr_frontSide entry
        then do
          -- the ray starts outside: it enters at @entry@ and leaves at the next hit
          guard (hr_t entry < tmax)
          exit <- nextHit (hr_t entry)
          Just (hr_t entry, min tmax (hr_t exit))
        else
          -- the ray starts inside: it is in the medium from @tmin@ until it leaves
          Just (tmin, min tmax (hr_t entry))

    hitMedium :: Double -> Ray -> Interval -> State StdGen (Maybe (HitRecord, ()))
    hitMedium time ray@(Ray orig dir) ival =
      maybe (pure Nothing) (\(t1, t2) -> scatterWithin t1 t2 <$> state random) (insideSpan time ray ival)
      where
        scatterWithin t1 t2 rand = do
          let hitDist = negInvDensity * log rand
          guard (hitDist < t2 - t1)
          let t = t1 + hitDist
              hit = HitRecord
                { hr_t = t
                , hr_point = orig + t *^ dir
                , hr_normal = negate dir -- arbitrary
                , hr_frontSide = True
                , hr_uv = V2 0 0 -- arbitrary
                }
          Just (hit, ())

-- | Group multiple geometric objects into a single object. When testing if a ray hits a group, 
-- every constituent of the group is tested without regard to its position. With a large number of objects,
-- use 'bvhTree' for greater efficiency.
--
-- Constituents are tested in list order. Each test searches only up to the closest hit
-- found so far, so the result is the closest hit overall.
{-# SPECIALISE group :: [Geometry Identity a] -> Geometry Identity a #-}
group :: Monad m => [Geometry m a] -> Geometry m a
group obs = Geometry (boxJoin (map boundingBox obs)) hitGroup
  where
    hitGroup time ray (tmin, tmax) = snd <$> foldM step (tmax, Nothing) obs
      where
        step best@(closest, _) (Geometry _ hitObj) =
          hitObj time ray (tmin, closest)
            <&> maybe best (\found@(hit, _) -> (hr_t hit, Just found))

-- | A single node in a bounding volume hierarchy. Before testing whether a ray hits each child,
-- it tests whether the ray hits a bounding box containing the two children.
--
-- The left child is tested first. The right child is then searched only up to the
-- left child's hit (if any), so a right-child hit is closer and takes precedence.
{-# SPECIALISE bvhNode :: Geometry Identity a -> Geometry Identity a -> Geometry Identity a #-}
bvhNode :: Monad m => Geometry m a -> Geometry m a -> Geometry m a
bvhNode (Geometry leftBox hitLeft) (Geometry rightBox hitRight) = Geometry combinedBox hitEither
  where
    combinedBox = boxJoin [leftBox, rightBox]

    hitEither time ray ival@(tmin, _)
      | not (overlapsBox combinedBox ray ival) = pure Nothing
      | otherwise = do
          leftResult <- hitLeft time ray ival
          case leftResult of
            Nothing -> hitRight time ray ival
            Just (hit, _) -> (<|> leftResult) <$> hitRight time ray (tmin, hr_t hit)

-- | Group multiple geometric objects into a single object. 
-- The objects are organized into a tree based on their positions, and then a 'bvhNode'
-- is created for each node in the tree.
--
-- At each level, the objects are sorted by the 'midpoint' of their bounding boxes along
-- dimension 'longestDim' of the combined box, then split into two halves. The list
-- must be non-empty.
{-# SPECIALISE bvhTree :: [Geometry Identity a] -> Geometry Identity a #-}
bvhTree :: Monad m => [Geometry m a] -> Geometry m a
bvhTree [] = error "bvhTree: empty list"
bvhTree [single] = single
bvhTree objects = bvhNode (bvhTree firstHalf) (bvhTree secondHalf)
  where
    axis = longestDim (boxJoin (map boundingBox objects))
    ordered = sortOn (midpoint . component axis . boundingBox) objects
    (firstHalf, secondHalf) = splitAt (length objects `div` 2) ordered

-- | Apply an affine transformation (represented as a 4 by 4 matrix whose bottom row is 0 0 0 1) to a geometric object.
-- The transformation should be Euclidean (translation, rotation, reflection, or a composition thereof); otherwise, the
-- normal vectors of the resulting geometry will be incorrect.
--
-- Rays are mapped into the object's own frame with the inverse matrix. Hit points and
-- normals are mapped back with the forward matrix (as a point and a direction,
-- respectively). The new bounding box is the hull of the transformed corners of the old one.
transform :: Functor m => M44 Double -> Geometry m a -> Geometry m a
transform m (Geometry bbox hitObj) = Geometry worldBox hitWorld
  where
    toWorld = dropLast m
    toLocal = dropLast (inv44 m)
    worldBox = boxHull [toWorld !* V4.point corner | corner <- allCorners bbox]

    backToWorld hit = hit
      { hr_point = toWorld !* V4.point (hr_point hit)
      , hr_normal = toWorld !* V4.vector (hr_normal hit)
      }

    hitWorld time (Ray orig dir) ival =
      let localRay = Ray (toLocal !* V4.point orig) (toLocal !* V4.vector dir)
      in fmap (fmap (first backToWorld)) (hitObj time localRay ival)

-- | Translation by the given vector: points are shifted by it, while direction
-- vectors (fourth coordinate 0) are unaffected.
translate :: Vec3 -> M44 Double
translate (V3 dx dy dz) = V4 row1 row2 row3 lastRow
  where
    row1 = V4 1 0 0 dx
    row2 = V4 0 1 0 dy
    row3 = V4 0 0 1 dz
    lastRow = V4 0 0 0 1

-- | Rotation about the X axis by the given angle in radians; positive angles turn +Y towards +Z.
rotateX :: Double -> M44 Double
rotateX angle =
  V4 (V4 1 0 0 0)
     (V4 0 c (negate s) 0)
     (V4 0 s c 0)
     (V4 0 0 0 1)
  where
    (c, s) = (cos angle, sin angle)

-- | Rotation about the Y axis by the given angle in radians; positive angles turn +Z towards +X.
rotateY :: Double -> M44 Double
rotateY angle =
  V4 (V4 c 0 s 0)
     (V4 0 1 0 0)
     (V4 (negate s) 0 c 0)
     (V4 0 0 0 1)
  where
    (c, s) = (cos angle, sin angle)

-- | Rotation about the Z axis by the given angle in radians; positive angles turn +X towards +Y.
rotateZ :: Double -> M44 Double
rotateZ angle =
  V4 (V4 c (negate s) 0 0)
     (V4 s c 0 0)
     (V4 0 0 1 0)
     (V4 0 0 0 1)
  where
    (c, s) = (cos angle, sin angle)

-- | Scaling centered at the origin. This should not be used with 'transform'.
-- It can instead be applied to a mesh with 'transformVertices', which only moves vertex positions.
scale :: Double -> M44 Double
scale a = V4 (V4 a 0 0 0) (V4 0 a 0 0) (V4 0 0 a 0) (V4 0 0 0 1)

-- [private]
-- Discard the fourth component of a 4-vector. Applied to a 4x4 matrix (a vector of
-- rows), this keeps its top three rows.
dropLast :: V4 a -> V3 a
dropLast v = case v of
  V4 x y z _ -> V3 x y z

-- TODO: this produces incorrect results for objects with solid textures
-- | Create a motion-blurred object that is translated by the first vector at time 0
-- and by the second vector at time 1.
--
-- At intermediate times the offset is interpolated linearly. The ray is shifted into
-- the object's frame, and the hit point is shifted back; normals are unchanged.
moving :: Functor m => Vec3 -> Vec3 -> Geometry m a -> Geometry m a
moving v0 v1 (Geometry bbox hitObj) = Geometry sweptBox hitMoving
  where
    sweptBox = boxJoin [shiftBox v0 bbox, shiftBox v1 bbox]
    hitMoving time (Ray orig dir) ival =
      let offset = (1 - time) *^ v0 + time *^ v1
          reposition hit = hit { hr_point = hr_point hit + offset }
      in fmap (fmap (first reposition)) (hitObj time (Ray (orig - offset) dir) ival)
