#include <OSGColorConversionFunctions.h>
#include <Imath/ImathVec.h>
#include <Imath/ImathMatrix.h>
#include <ImfChromaticities.h>

#include <oneapi/tbb/parallel_for.h>


/// Builds a Bradford chromatic-adaptation matrix that maps XYZ values relative to
/// `srcWhite` onto XYZ values relative to `dstWhite`.
///
/// @param srcWhite xy chromaticity of the source white point.
/// @param dstWhite xy chromaticity of the destination white point.
/// @return 4x4 Imath matrix meant to be applied as a row vector (`v * M`); the
///         homogeneous row/column are identity.
///
/// NOTE(review): a white point with y == 0 divides by zero below.
Imath::M44f getWhitePointAdaptationMatrix(const Imath::V2f& srcWhite, const Imath::V2f& dstWhite)
{
    // Bradford cone-response matrix and its inverse, stored transposed so they can be
    // used with Imath's `vector * matrix` operator. Imath is used instead of the OpenSG
    // types because its operators keep this maths compact.
    static const Imath::M44f bradfordCPM(
         0.895100f, -0.750200f,  0.038900f, 0.000000f,
         0.266400f,  1.713500f, -0.068500f, 0.000000f,
        -0.161400f,  0.036700f,  1.029600f, 0.000000f,
         0.000000f,  0.000000f,  0.000000f, 1.000000f);

    static const Imath::M44f inverseBradfordCPM(
         0.986993f,  0.432305f, -0.008529f, 0.000000f,
        -0.147054f,  0.518360f,  0.040043f, 0.000000f,
         0.159963f,  0.049291f,  0.968487f, 0.000000f,
         0.000000f,  0.000000f,  0.000000f, 1.000000f);

    // xy chromaticity -> XYZ, normalised so that Y == 1.
    const auto chromaticityToXYZ = [](const Imath::V2f& xy)
    {
        return Imath::V3f(xy.x / xy.y, 1.0f, (1.0f - xy.x - xy.y) / xy.y);
    };

    // Express both white points in cone space; the per-channel ratio is the von Kries gain.
    const Imath::V3f srcCone = chromaticityToXYZ(srcWhite) * bradfordCPM;
    const Imath::V3f dstCone = chromaticityToXYZ(dstWhite) * bradfordCPM;
    const Imath::V3f gain = dstCone / srcCone;

    const Imath::M44f coneScale(
        gain[0], 0.0f,    0.0f,    0.0f,
        0.0f,    gain[1], 0.0f,    0.0f,
        0.0f,    0.0f,    gain[2], 0.0f,
        0.0f,    0.0f,    0.0f,    1.0f);

    // XYZ -> cone space -> scale -> back to XYZ (row-vector order, left to right).
    return bradfordCPM * coneScale * inverseBradfordCPM;
}

/// Computes the matrix that converts linear RGB in the colour space described by
/// `fileChroma` into linear Rec.709 RGB, including Bradford white-point adaptation.
///
/// @param fileChroma primaries and white point of the source data.
/// @return identity when `fileChroma` matches Rec.709 exactly; otherwise the
///         combined RGB -> XYZ -> adapted XYZ -> Rec.709 RGB matrix, copied into an
///         OpenSG matrix with the same [i][j] indexing as the Imath result.
osg::Matrix getToLinRec709Matrix(Imf::Chromaticities fileChroma )
{
    // Rec.709 primaries and white point as xy chromaticities.
    static const Imf::Chromaticities bt709(Imath::V2f(0.6400, 0.3300),
                                           Imath::V2f(0.3000, 0.6000),
                                           Imath::V2f(0.1500, 0.0600),
                                           Imath::V2f(0.3127, 0.3290));

    // Exact float comparison on purpose: chromaticities that are only approximately
    // Rec.709 fall through to the full conversion, whose result should then be
    // (numerically close to) the identity anyway.
    const bool alreadyRec709 = fileChroma.red   == bt709.red
                            && fileChroma.green == bt709.green
                            && fileChroma.blue  == bt709.blue
                            && fileChroma.white == bt709.white;
    if (alreadyRec709)
        return osg::Matrix::identity();

    // file RGB -> XYZ -> white-point adaptation -> Rec.709 RGB (Imath row-vector order).
    const Imath::M44f whiteAdaptation = getWhitePointAdaptationMatrix(fileChroma.white, bt709.white);
    const Imath::M44f fileToRec709 = RGBtoXYZ(fileChroma, 1.0f) * whiteAdaptation * XYZtoRGB(bt709, 1.0f);

    osg::Matrix result;
    for (int row = 0; row < 4; ++row)
    {
        for (int col = 0; col < 4; ++col)
            result[row][col] = fileToRec709[row][col];
    }
    return result;
}

/// Convenience overload taking the source primaries and white point as OpenSG
/// xy vectors; repacks them into an Imf::Chromaticities and delegates to the
/// Chromaticities overload.
osg::Matrix getToLinRec709Matrix(const osg::Vec2f& red, const osg::Vec2f& green, const osg::Vec2f& blue, const osg::Vec2f& white)
{
    const auto toImathV2f = [](const osg::Vec2f& v) { return Imath::V2f(v[0], v[1]); };

    const Imf::Chromaticities sourceChroma(toImathV2f(red),
                                           toImathV2f(green),
                                           toImathV2f(blue),
                                           toImathV2f(white));
    return getToLinRec709Matrix(sourceChroma);
}


namespace
{
/// Multiplies the first three channels of every pixel in `data` by `colorMatrix`, in place.
///
/// @tparam ChannelT        per-channel storage type (osg::Real16 or float at the call sites).
/// @param colorMatrix      transform applied to each (c0, c1, c2) triple via osg::Matrix::mult.
/// @param data             start of an interleaved pixel buffer.
/// @param pixelCount       number of pixels to process.
/// @param channelsPerPixel distance between consecutive pixels, in channels (>= 3).
///
/// Channels past the third are not touched. Pixels are processed in parallel with
/// TBB; each iteration writes only its own pixel.
template <typename ChannelT>
void applyColorMatrix(const osg::Matrix &colorMatrix, ChannelT *data,
                      size_t pixelCount, size_t channelsPerPixel)
{
    oneapi::tbb::parallel_for(
        oneapi::tbb::blocked_range<size_t>(0, pixelCount),
        [&](const oneapi::tbb::blocked_range<size_t> &range)
        {
            for (size_t idx = range.begin(); idx != range.end(); ++idx)
            {
                ChannelT *pixel = data + idx * channelsPerPixel;
                osg::Vec3f color(pixel[0], pixel[1], pixel[2]);
                colorMatrix.mult(color);
                pixel[0] = color[0];
                pixel[1] = color[1];
                pixel[2] = color[2];
            }
        });
}
} // namespace

/// Applies `colorMatrix` in place to the first three channels of a floating-point image.
///
/// Returns without modifying anything when `image` is null, when `colorMatrix` is the
/// identity, when the image data type is neither OSG_FLOAT16_IMAGEDATA nor
/// OSG_FLOAT32_IMAGEDATA, or when a pixel has fewer than three channels. Channels
/// beyond the third (e.g. alpha) are left untouched.
///
/// The whole pixel buffer is converted, not only the first width*height*depth*sides
/// pixels, so every mipmap level and frame stays in the same colour space.
///
/// NOTE(review): channels 0..2 are assumed to be R,G,B in that order; a BGR/BGRA pixel
/// format would need its channels swapped first — confirm which formats reach this.
void convertImage(const osg::Matrix &colorMatrix, const osg::ImagePtr &image)
{
    if(image == osg::NullFC)
        return;
    // no need to convert if the matrix is identity
    if(colorMatrix == osg::Matrix::identity())
        return;

    const bool isHalf  = image->getDataType() == osg::Image::OSG_FLOAT16_IMAGEDATA;
    const bool isFloat = image->getDataType() == osg::Image::OSG_FLOAT32_IMAGEDATA;
    if(!isHalf && !isFloat)
        return;

    const size_t bytesPerPixel    = image->getBpp();
    const size_t channelSize      = isHalf ? sizeof(osg::Real16) : sizeof(float);
    const size_t channelsPerPixel = bytesPerPixel / channelSize;
    if(channelsPerPixel < 3)
        return;

    // Assumes Image::getSize() with its default arguments (mipmaps, frames and sides
    // included) is the byte size of the entire pixel buffer; every pixel in it has the
    // same layout, so dividing by the pixel size yields the total pixel count.
    const size_t pixelCount = static_cast<size_t>(image->getSize()) / bytesPerPixel;

    beginEditCP(image, osg::Image::PixelFieldMask);
    if(isHalf)
        applyColorMatrix(colorMatrix, reinterpret_cast<osg::Real16*>(image->editData()), pixelCount, channelsPerPixel);
    else
        applyColorMatrix(colorMatrix, reinterpret_cast<float*>(image->editData()), pixelCount, channelsPerPixel);
    endEditCP(image, osg::Image::PixelFieldMask);
}