Module narya.preprocessing.image

Expand source code
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import mxnet as mx
import numpy as np
import cv2
import segmentation_models as sm
from gluoncv.data.transforms.presets.ssd import transform_test
from ..utils.image import (
    np_img_to_torch_img,
    normalize_single_image_torch,
    torch_img_to_np_img,
)

"""Builds the preprocessing function for each model. They all use torch/keras/gluoncv functions depending on the model.
Arguments:
    input_shape: Tuple of integer, the input_shape the model needs to take
Returns:
    preprocessing: function that takes an image as input, and returns the preprocessed image.
"""


def _build_reid_preprocessing(input_shape):
    """Builds the preprocessing function for the Re Identification model.

    """

    def preprocessing(input_img, **kwargs):

        to_normalize = True if np.percentile(input_img, 98) > 1.0 else False

        if len(input_img.shape) == 4:
            print(
                "Only preprocessing single image, we will consider the first one of the batch"
            )
            image = input_img[0] / 255.0 if to_normalize else input_img[0] / 1.0
        else:
            image = input_img / 255.0 if to_normalize else input_img / 1.0

        image = cv2.resize(image, (input_shape[1], input_shape[0]))

        image = np_img_to_torch_img(image)
        image = normalize_single_image_torch(
            image, img_mean=[0.485, 0.456, 0.406], img_std=[0.229, 0.224, 0.225]
        )
        return image.expand(1, -1, -1, -1)

    return preprocessing


def _build_keypoint_preprocessing(input_shape, backbone):
    """Builds the preprocessing function for the Field Keypoint Detector Model.

    """
    sm_preprocessing = sm.get_preprocessing(backbone)

    def preprocessing(input_img, **kwargs):

        to_normalize = False if np.percentile(input_img, 98) > 1.0 else True

        if len(input_img.shape) == 4:
            print(
                "Only preprocessing single image, we will consider the first one of the batch"
            )
            image = input_img[0] * 255.0 if to_normalize else input_img[0] * 1.0
        else:
            image = input_img * 255.0 if to_normalize else input_img * 1.0

        image = cv2.resize(image, input_shape)
        image = sm_preprocessing(image)
        return image

    return preprocessing


def _build_tracking_preprocessing(input_shape):
    """Builds the preprocessing function for the Player/Ball Tracking Model.

    """

    def preprocessing(input_img, **kwargs):

        to_normalize = False if np.percentile(input_img, 98) > 1.0 else True

        if len(input_img.shape) == 4:
            print(
                "Only preprocessing single image, we will consider the first one of the batch"
            )
            image = input_img[0] * 255.0 if to_normalize else input_img[0] * 1.0
        else:
            image = input_img * 255.0 if to_normalize else input_img * 1.0

        image = cv2.resize(image, input_shape)
        x, _ = transform_test(mx.nd.array(image), min(input_shape))
        return x

    return preprocessing


def _build_homo_preprocessing(input_shape):
    """Builds the preprocessing function for the Deep Homography estimation Model.

    """

    def preprocessing(input_img, **kwargs):

        if len(input_img.shape) == 4:
            print(
                "Only preprocessing single image, we will consider the first one of the batch"
            )
            image = input_img[0]
        else:
            image = input_img

        image = cv2.resize(image, input_shape)
        image = torch_img_to_np_img(
            normalize_single_image_torch(np_img_to_torch_img(image))
        )
        return image

    return preprocessing