# copyright (c) 2018- polygoniq xyz s.r.o.

import typing
import collections
import bpy
import bpy_extras.object_utils
import logging
logger = logging.getLogger(__name__)


if "polib" not in locals():
    import polib
    from . import utils
else:
    import importlib
    polib = importlib.reload(polib)
    utils = importlib.reload(utils)


def get_object_2d_bounds(
    scene: bpy.types.Scene,
    camera: bpy.types.Object,
    obj: bpy.types.Object
) -> typing.Tuple[float, float]:
    """Returns width and height pixels of rectangular bounding box in camera 2D space

    All 8 corners of the object's 3D bounding box (built by polib.linalg.AlignedBox) are projected
    through `camera` with `bpy_extras.object_utils.world_to_camera_view`. They are then converted
    to pixels using `scene.render.resolution_x` / `scene.render.resolution_y`. The render
    resolution percentage is not applied.

    Result can be bigger than render resolution as it's not clipped to camera plane. Returns zeros
    if object's 3D bounding box is not visible from camera. That is the case when all corners are
    behind the camera, or all corners are outside the frustum on one side (left, right, bottom
    or top).

    Args:
        scene: Scene whose render resolution defines the pixel space and the frustum limits.
        camera: Camera object used for the projection.
        obj: Object whose bounds are measured.

    Returns:
        (width, height) of the projected bounding rectangle in pixels, or (0, 0) when culled.
    """

    logger.debug(
        f"Asked for 2D bounds of object {obj.name} in scene {scene.name} using camera {camera.name}.")

    # Take the resolution from the passed-in scene for both the projection and the culling.
    # Reading bpy.context.scene here would cull against the wrong resolution whenever the caller
    # measures in a scene that is not the active one.
    resolution_x = scene.render.resolution_x
    resolution_y = scene.render.resolution_y

    bounding_box = polib.linalg.AlignedBox()
    bounding_box.extend_by_object(obj)

    min_x = float("inf")
    max_x = float("-inf")
    min_y = float("inf")
    max_y = float("-inf")
    min_depth = float("inf")
    max_depth = float("-inf")

    for corner in bounding_box.get_corners():
        # get the 2D pixel coordinates of the bounding box corner. NDC = normalized device coords
        corner_ndc_x, corner_ndc_y, corner_depth = \
            bpy_extras.object_utils.world_to_camera_view(scene, camera, corner)

        # Convert NDC into pixel values. The y axis is flipped (1 - y) so pixel rows are counted
        # from the opposite edge of the frame than the NDC y coordinate.
        corner_2d_x = corner_ndc_x * resolution_x
        corner_2d_y = (1.0 - corner_ndc_y) * resolution_y

        # NOTE(review): corners with negative depth (behind the camera) still contribute to the
        # min/max below. Their projected x/y may not be meaningful, so bounds of objects that are
        # only partially behind the camera may be inaccurate. TODO confirm.
        min_x = min(min_x, corner_2d_x)
        max_x = max(max_x, corner_2d_x)
        min_y = min(min_y, corner_2d_y)
        max_y = max(max_y, corner_2d_y)
        min_depth = min(min_depth, corner_depth)
        max_depth = max(max_depth, corner_depth)

    if max_depth < 0:
        # the object is entirely behind the camera
        logger.debug(
            f"Object {obj.name} 2D bounds ended up as 0, 0 because it's entirely behind the "
            f"camera. min_depth: {min_depth}, max_depth: {max_depth}.")
        return 0, 0

    if max_x < 0:
        # the object is entirely left of the frustum
        logger.debug(
            f"Object {obj.name} 2D bounds ended up as 0, 0 because it's entirely left of the "
            f"frustum of the camera. min_x: {min_x}, max_x: {max_x}.")
        return 0, 0

    if min_x > resolution_x:
        # the object is entirely right of the frustum
        logger.debug(
            f"Object {obj.name} 2D bounds ended up as 0, 0 because it's entirely right of the "
            f"frustum of the camera. min_x: {min_x}, max_x: {max_x}.")
        return 0, 0

    if max_y < 0:
        # the object is entirely under the frustum
        logger.debug(
            f"Object {obj.name} 2D bounds ended up as 0, 0 because it's entirely down under the "
            f"frustum of the camera. min_y: {min_y}, max_y: {max_y}.")
        return 0, 0

    if min_y > resolution_y:
        # the object is entirely over the frustum
        logger.debug(
            f"Object {obj.name} 2D bounds ended up as 0, 0 because it's entirely up over the "
            f"frustum of the camera. min_y: {min_y}, max_y: {max_y}.")
        return 0, 0

    size_x = max_x - min_x
    size_y = max_y - min_y
    assert size_x >= 0
    assert size_y >= 0

    logger.debug(
        f"Object {obj.name} 2D bounds ended up as {size_x}, {size_y}. "
        f"min_x: {min_x}, max_x: {max_x}, min_y: {min_y}, max_y: {max_y}, "
        f"min_depth: {min_depth}, max_depth: {max_depth}.")
    return size_x, size_y


def get_size_map_for_objects(
    scene: bpy.types.Scene,
    camera: bpy.types.Object,
    objects: typing.Iterable[bpy.types.Object],
    size_factor: float,
    min_size: int,
    max_size: int,
    size_pot_only: bool = True
) -> typing.DefaultDict[bpy.types.Image, int]:
    """Computes a target texture side size for every image used by `objects`.

    For each object, the larger side of its on-screen 2D bounds (see get_object_2d_bounds) is
    scaled by `size_factor` and rounded. The result is clamped into [min_size, max_size]; a
    result of 0 becomes 1. With `size_pot_only`, it is then rounded up to a power of two, and
    values strictly between 1 and 32 are raised to 32. Because this happens after the clamp,
    the size can end up larger than `max_size`.

    Each image used by the object (per utils.get_images_used_in_object) keeps the largest size
    requested so far. The stored value is capped at the image's own larger dimension.

    Returns:
        A defaultdict mapping images to side sizes. Missing keys read as 1. Every image seen
        gets an entry. The original author notes that 0 means "use the original". This function
        only stores 0 when an image reports a zero `image.size`.
    """
    assert max_size >= min_size

    sizes: typing.DefaultDict[bpy.types.Image, int] = collections.defaultdict(lambda: 1)

    for obj in objects:
        width, height = get_object_2d_bounds(scene, camera, obj)
        target = round(max(width, height) * size_factor)
        assert target >= 0

        clamped = min(max(target, min_size), max_size)
        # images can't be scaled down to 0 pixels
        target = clamped if clamped != 0 else 1

        if size_pot_only:
            # round up to the next power of two (1 stays 1)
            target = 1 << (target - 1).bit_length()
            # sizes between 1 and 32 only produce wasted files
            if 1 < target < 32:
                target = 32

        for image in utils.get_images_used_in_object(obj):
            # reading through the defaultdict also registers the image with the default size
            current = sizes[image]
            if current >= target:
                continue
            logger.debug(
                f"Upgrading image {image.name} from size {current} to {target} because "
                f"of its usage in object {obj.name}, 2D bounds of object: {width}, {height}."
            )
            # never ask for more than the image's own largest dimension
            sizes[image] = min(target, max(image.size[0], image.size[1]))

    return sizes
