/*---------------------------------------------------------------------------*\
 *                                License                                    *
 *                                                                           *
 * This library is free software; you can redistribute it and/or modify it   *
 * under the terms of the GNU Library General Public License as published    *
 * by the Free Software Foundation, version 2.                               *
 *                                                                           *
 * This library is distributed in the hope that it will be useful, but       *
 * WITHOUT ANY WARRANTY; without even the implied warranty of                *
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU         *
 * Library General Public License for more details.                          *
 *                                                                           *
 * You should have received a copy of the GNU Library General Public         *
 * License along with this library; if not, write to the Free Software       *
 * Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.                 *
 *                                                                           *
\*---------------------------------------------------------------------------*/

#include <algorithm>
#include <atomic>
#include <chrono>
#include <cstddef>
#include <cstdint>
#include <iostream>
#include <memory>
#include <string>
#include <type_traits>
#include <vector>

#include <oneapi/tbb/parallel_for.h>
#include <oneapi/tbb/tick_count.h>
#include <nvtt/nvtt.h>

#include "OSGImageFileHandler.h"

#include "OSGImageCompressor.h"

//#define VR_DEBUG_SPEED

OSG_USING_NAMESPACE

// --------------------------------------------------------------------------

class OutputHandler : public nvtt::OutputHandler
{
public:
    OutputHandler() = default;

    void beginImage(int size, int width, int height, int depth, int face, int miplevel) override
    {
    }

    bool writeData(const void *data, int size) override
    {
        std::copy_n(static_cast<const char *>(data), size, _img->getRawData(_mipmap, 0, _side));
        return true;
    }

    void endImage() override
    {
    }

    void setImage(const ImagePtr &img)
    {
        _img = img;
    }

    void setSide(std::uint32_t side)
    {
        _side = side;
    }

    void setMipmap(std::uint32_t mipmap)
    {
        _mipmap = mipmap;
    }

private:

    ImagePtr _img{};
    std::uint32_t _side{};
    std::uint32_t _mipmap{};
};

// --------------------------------------------------------------------------

struct Surface
{
    void setupImageProperties(const ImagePtr &osgImage)
    {
        if(osgImage->getDataType() == Image::OSG_UINT8_IMAGEDATA)
        {
            if( osgImage->hasAttribute(osg::Image::OSG_ATTR_BUMPMAP) || 
                osgImage->hasAttribute(osg::Image::OSG_ATTR_NORMALMAP) || 
                osgImage->hasAttribute(osg::Image::OSG_ATTR_DISPLACEMENTMAP))
            {
                _isSRGB = false;
            }
            else
            {
                _isSRGB = true;
            }
        }
        else
        {
            _isSRGB = false;
        }

        _premultiplyAlpha = !osgImage->getPremultipliedAlpha();
        
    }

    bool init(const ImagePtr &osgImg, UInt32 side, nvtt::Context* context)
    {
        _context = context;
        _surfaces.clear();
        if (osgImg == NullFC)
            return false;
        if(osgImg->getComponents() < 3)
            return false;
        if(osgImg->getDepth() != 1)
            return false;
 
        setupImageProperties(osgImg);
        
        bool ret = true;
        if(osgImg->isCompressed())
        {
            const auto nvttFormat = convertToNvttFormat(osgImg->getInternalPixelFormat());
            const UInt32 mipMapCount = osgImg->getMipMapCount();
            _surfaces.resize(mipMapCount);
            tbb::parallel_for( tbb::blocked_range<int>(0, mipMapCount), [&](tbb::blocked_range<int> r)
            {
                for (std::uint32_t level=r.begin(); level<r.end(); ++level)
                {
                    UInt32 w{};
                    UInt32 h{};
                    UInt32 d{};
                    osgImg->calcMipmapGeometry(level, w, h, d);
                    // setImage2D is really slow for large resolutions with mipmaps. It can take up to 10 seconds for a 15kx15k image.
                    if(!_surfaces[level].setImage2D(nvttFormat, w, h, osgImg->getRawData(level, 0, side)))
                        ret = false;
                }
            });
            
        }
        else
        {
            const auto dataType = osgImg->getDataType();
            if(dataType == Image::OSG_UINT8_IMAGEDATA)
            {
                ret = setImage<osg::UInt8>(osgImg, side);
            }
            else if(dataType == Image::OSG_FLOAT16_IMAGEDATA)
            {
                ret = setImage<osg::Real16>(osgImg, side);
            }
            else if(dataType == Image::OSG_FLOAT32_IMAGEDATA)
            {
                ret = setImage<osg::Real32>(osgImg, side);
            }
        }
        //std::cout << "*** ImageCompressor::init: mipmapcount: " << surface.countMipmaps() << std::endl;
        return ret;
    }

    int width(UInt32 level = 0) const
    {
        if(level >= _surfaces.size())
            return 0;
        return _surfaces[level].width();
    }

    int height(UInt32 level = 0) const
    {
        if(level >= _surfaces.size())
            return 0;
        return _surfaces[level].height();
    }

    int depth(UInt32 level = 0) const
    {
        if(level >= _surfaces.size())
            return 0;
        return _surfaces[level].depth();
    }

    UInt32 getMipmapCount() const
    {
        return _surfaces.size();
    }

    void flipX()
    {
        for(auto &surface : _surfaces)
            surface.flipX();
    }

    void flipY()
    {
        for(auto &surface : _surfaces)
            surface.flipY();
    }

    void rotate90Left()
    {
        for(auto &surface : _surfaces)
            rotate90Left(surface);
    }

    void rotate90Right()
    {
        for(auto &surface : _surfaces)
            rotate90Right(surface);
    }

    void createMipmaps()
    {
        if(_surfaces.empty())
            return;

        if(_surfaces.size() > 1)
        {
            return;
        }
        _surfaces.resize(1);

        if(_context != nullptr && _context->isCudaAccelerationEnabled() && _surfaces[0].gpuData() == nullptr)
        {
            _surfaces[0].ToGPU();
        }

        std::uint32_t i = 0;
        nvtt::Surface ms = _surfaces[i].clone();

        while(ms.canMakeNextMipmap())
        {
            if(_isSRGB)
                ms.toLinearFromSrgb();
            
            if(_premultiplyAlpha)
                ms.premultiplyAlpha();

            ms.buildNextMipmap(nvtt::MipmapFilter_Kaiser);
            
            if(_premultiplyAlpha)
                ms.demultiplyAlpha();

            if(_isSRGB)
                ms.toSrgb();

            _surfaces.emplace_back(ms);
            ++i;
        }

#if 0
        std::cout << "ImageCompressor::createMipmaps: created " << getMipmapCount() << " mipmaps." << std::endl;
        for(auto &surface : _surfaces)
            std::cout << "ImageCompressor::createMipmaps: mipmap: " << surface.width() << " x " << surface.height() << std::endl;
#endif
    }

    bool save(const char * fileName, bool hasAlpha = false, bool hdr = false) const
    {
        if(_surfaces.empty())
            return false;
        return _surfaces[0].save(fileName, hasAlpha, hdr);
    }

    template<typename T>
    bool setImage(const ImagePtr &osgImg, std::uint32_t side)
    {
        setupImageProperties(osgImg);

        bool ret = true;
        const std::uint32_t components = osgImg->getComponents();
        const UInt32 mipMapCount = osgImg->getMipMapCount();
        const bool swapRGB = (osgImg->getPixelFormat() == osg::Image::OSG_BGR_PF || osgImg->getPixelFormat() == osg::Image::OSG_BGRA_PF);
        for (UInt32 level = 0; level < mipMapCount; ++level)
        {
            UInt32 w{};
            UInt32 h{};
            UInt32 d{};
            osgImg->calcMipmapGeometry(level, w, h, d);
            // uncompressed nvtt surface is always in RGBA format!
            const std::uint32_t s = w * h;
            std::vector<T> r(s);
            std::vector<T> g(s);
            std::vector<T> b(s);
            std::vector<T> a(s);

            T alpha{};
            if constexpr (std::is_same_v<T, osg::UInt8>)
            {
                alpha = 255;
            }
            else
            {
                alpha = 1.0f;
            }

            const auto src = reinterpret_cast<T *>(osgImg->getRawData(level, 0, side));

            std::uint32_t j{};
            for (unsigned int i = 0; i < s; ++i)
            {
                r[i] = src[j];
                g[i] = src[j + 1];
                b[i] = src[j + 2];
                if(components == 4)
                    a[i] = src[j + 3];
                else
                    a[i] = alpha;
                j += components;
            }

            const auto nvttInputFormat = convertToNvttInputFormat(osgImg->getDataType());
            nvtt::Surface surface;
            if(swapRGB)
            {
                if(!surface.setImage(nvttInputFormat, w, h, d, b.data(), g.data(), r.data(), a.data()))
                    ret = false;
            }
            else
            {
                if(!surface.setImage(nvttInputFormat, w, h, d, r.data(), g.data(), b.data(), a.data()))
                    ret = false;
            }
            if(_context != nullptr && _context->isCudaAccelerationEnabled() && surface.gpuData() == nullptr)
            {
                surface.ToGPU();
            }
            _surfaces.emplace_back(std::move(surface));
        }
        return ret;
    }

    static void rotate90Left(nvtt::Surface &surface)
    {
        if(surface.gpuData() != nullptr)
            surface.ToCPU();

        int w = surface.width();
        int h = surface.height();

        nvtt::Surface srcSurface = surface.clone();

        auto srcR = srcSurface.channel(0);
        auto srcG = srcSurface.channel(1);
        auto srcB = srcSurface.channel(2);
        auto srcA = srcSurface.channel(3);

        auto dstR = surface.channel(0);
        auto dstG = surface.channel(1);
        auto dstB = surface.channel(2);
        auto dstA = surface.channel(3);

        //for (std::uint32_t y = 0; y < h; ++y)
        tbb::parallel_for( tbb::blocked_range<int>(0,h),
                       [&](tbb::blocked_range<int> r)
        {
            for (std::uint32_t y=r.begin(); y<r.end(); ++y)
            {
                std::uint32_t oldIndex = y * w;
                std::uint32_t newIndex = (h - 1) * w + y;

                for (int x = 0; x < w; ++x)
                {
                    dstR[newIndex] = srcR[oldIndex];
                    dstG[newIndex] = srcG[oldIndex];
                    dstB[newIndex] = srcB[oldIndex];
                    dstA[newIndex] = srcA[oldIndex];

                    ++newIndex;
                    ++oldIndex;

                    newIndex -= (w + 1);
                }
            }
        });
    }

    static void rotate90Right(nvtt::Surface &surface)
    {
        if(surface.gpuData() != nullptr)
            surface.ToCPU();

        int w = surface.width();
        int h = surface.height();

        nvtt::Surface srcSurface = surface.clone();

        auto srcR = srcSurface.channel(0);
        auto srcG = srcSurface.channel(1);
        auto srcB = srcSurface.channel(2);
        auto srcA = srcSurface.channel(3);

        auto dstR = surface.channel(0);
        auto dstG = surface.channel(1);
        auto dstB = surface.channel(2);
        auto dstA = surface.channel(3);

        //for (std::uint32_t y = 0; y < h; ++y)
        tbb::parallel_for( tbb::blocked_range<int>(0,h),
                       [&](tbb::blocked_range<int> r)
        {
            for (std::uint32_t y=r.begin(); y<r.end(); ++y)
            {
                std::uint32_t oldIndex = y * w;
                std::uint32_t newIndex = (w - y - 1);

                for (int x = 0; x < w; ++x)
                {
                    dstR[newIndex] = srcR[oldIndex];
                    dstG[newIndex] = srcG[oldIndex];
                    dstB[newIndex] = srcB[oldIndex];
                    dstA[newIndex] = srcA[oldIndex];

                    ++newIndex;
                    ++oldIndex;

                    newIndex += (w - 1);
                }
            }
        });
    }

    static nvtt::Format convertToNvttFormat(std::uint32_t format)
    {
        switch (format)
        {
            case Image::OSG_RGB_DXT1_PF:
                return nvtt::Format_BC1;
            case Image::OSG_RGBA_DXT1_PF:
                return nvtt::Format_BC1a;
            case Image::OSG_RGBA_DXT3_PF:
                return nvtt::Format_BC2;
            case Image::OSG_RGBA_DXT5_PF:
                return nvtt::Format_BC3;
            case Image::OSG_RGB_BC7_PF:
            case Image::OSG_RGBA_BC7_PF:
                return nvtt::Format_BC7;
            case Image::OSG_RGB_BC6U_PF:
                return nvtt::Format_BC6U;
            case Image::OSG_RGB_BC6S_PF:
                return nvtt::Format_BC6S;
            case Image::OSG_RGB_ASTC_PF:
            case Image::OSG_RGBA_ASTC_PF:
                return nvtt::Format_ASTC_LDR_4x4;
            default:
                return nvtt::Format_RGBA;
        }
        return nvtt::Format_RGBA;
    }

    static nvtt::InputFormat convertToNvttInputFormat(std::uint32_t type)
    {
        switch (type)
        {
            case Image::OSG_UINT8_IMAGEDATA:
                return nvtt::InputFormat_BGRA_8UB;
            case Image::OSG_FLOAT16_IMAGEDATA:
                return nvtt::InputFormat_RGBA_16F;
            case Image::OSG_FLOAT32_IMAGEDATA:
                return nvtt::InputFormat_RGBA_32F;
        }
        return nvtt::InputFormat_BGRA_8UB;
    }

    std::vector<nvtt::Surface> _surfaces{};
    bool _isSRGB{false};
    bool _premultiplyAlpha{false};
    nvtt::Context* _context{nullptr};
};
// --------------------------------------------------------------------------


/*! All sides of an image (1 for 2D images, 6 for cube maps) as Surface
    objects, plus the nvtt context used for compression and GPU work.
*/
struct ImageCompressor::SurfaceSet
{
    /*! Converts \a osgImg into one Surface per side.

        Uncompressed sources get a CUDA-enabled context, which is reused
        across calls. Compressed sources are decoded without a context.
        Returns false if the image is null or any side fails.
    */
    bool init(const ImagePtr &osgImg)
    {
        if (osgImg == NullFC)
            return false;

#ifdef VR_DEBUG_SPEED
        auto start = std::chrono::high_resolution_clock::now();
#endif
        const std::uint32_t sides = osgImg->getSideCount();
        _surfaces.resize(sides);

        if(osgImg->isCompressed())
            _context.reset();
        else if(!_context)
            _context = std::make_unique<nvtt::Context>(true);

        // Sides are initialised concurrently. An atomic AND keeps one side's
        // success from overwriting another side's failure.
        std::atomic<bool> ok{true};
        nvtt::Context *context = _context.get();
        tbb::parallel_for(tbb::blocked_range<std::uint32_t>(0, sides),
                          [&](const tbb::blocked_range<std::uint32_t> &range)
        {
            for (std::uint32_t side = range.begin(); side != range.end(); ++side)
            {
                if(!_surfaces[side].init(osgImg, side, context))
                    ok = false;
            }
        });

        if(!ok)
            return false;

#ifdef VR_DEBUG_SPEED
        auto end = std::chrono::high_resolution_clock::now();
        std::cout << "ImageCompressor::init: converting " << (osgImg->isCompressed() ? "and uncompressing " : "") << "osg image to nvtt surface: " << std::chrono::duration_cast<std::chrono::milliseconds>(end - start).count() << " ms." << std::endl;
#endif

        return true;
    }

    //! Reorders and reorients the sides from the DirectX to the OpenGL layout (a single 2D side is only flipped in Y).
    void convertCubemapLayoutFromDXToOGL()
    {
        if(_surfaces.empty())
            return;

#ifdef VR_DEBUG_SPEED
        auto start = std::chrono::high_resolution_clock::now();
#endif

        if (_surfaces.size() == 6)
        {
            convertCubemapDXToOGL(_surfaces);
        }
        else if(_surfaces.size() == 1)
        {
            _surfaces[0].flipY();
        }

#ifdef VR_DEBUG_SPEED
        auto end = std::chrono::high_resolution_clock::now();
        std::cout << "ImageCompressor::convertCubemapLayoutFromDXToOGL: adjusting layout: " << std::chrono::duration_cast<std::chrono::milliseconds>(end - start).count() << " ms." << std::endl;
#endif
    }

    UInt32 getWidth() const
    {
        if(_surfaces.empty())
            return 0;
        return _surfaces[0].width();
    }

    UInt32 getHeight() const
    {
        if(_surfaces.empty())
            return 0;
        return _surfaces[0].height();
    }

    UInt32 getDepth() const
    {
        if(_surfaces.empty())
            return 0;
        return _surfaces[0].depth();
    }

    UInt32 getSideCount() const
    {
        return _surfaces.size();
    }

    UInt32 getMipmapCount() const
    {
        if(_surfaces.empty())
            return 0;
        return _surfaces[0].getMipmapCount();
    }

    /*! Writes the stored surfaces into \a osgImg.

        A block-compressed target is encoded with nvtt, creating the context
        on demand. Any other target receives the raw channels.
        Returns false if the target is null or does not match the stored
        sides, mip levels or base size, or if the conversion fails.
    */
    bool convertTo(const ImagePtr &osgImg)
    {
        if(_surfaces.empty() || osgImg == NullFC)
            return false;

        const std::uint32_t sides = osgImg->getSideCount();
        const std::uint32_t mipMapCount = osgImg->getMipMapCount();
        if(!canReceive(osgImg, sides, mipMapCount))
            return false;

#ifdef VR_DEBUG_SPEED
        auto start = std::chrono::high_resolution_clock::now();
#endif

        bool ret = false;
        if(osgImg->isCompressed())
        {
            ret = compressInto(osgImg, sides, mipMapCount);
#ifdef VR_DEBUG_SPEED
            auto end = std::chrono::high_resolution_clock::now();
            std::cout << "ImageCompressor::convertTo: compress nvtt surface and converting nvtt surface to osg image: " << std::chrono::duration_cast<std::chrono::milliseconds>(end - start).count() << " ms." << std::endl;
#endif
        }
        else
        {
            ret = copyInto(osgImg, sides, mipMapCount);
#ifdef VR_DEBUG_SPEED
            auto end = std::chrono::high_resolution_clock::now();
            std::cout << "ImageCompressor::convertTo: converting nvtt surface to osg image: " << std::chrono::duration_cast<std::chrono::milliseconds>(end - start).count() << " ms." << std::endl;
#endif
        }
        return ret;
    }

    /*! True if every requested side and mip level exists and the base
        width/height match \a osgImg. Writing into a mismatched image would
        overrun its buffers.
    */
    bool canReceive(const ImagePtr &osgImg, std::uint32_t sides, std::uint32_t mipMapCount) const
    {
        if(sides > _surfaces.size())
            return false;
        for (std::uint32_t side = 0; side < sides; ++side)
        {
            if(_surfaces[side].getMipmapCount() < mipMapCount)
                return false;
        }
        return static_cast<UInt32>(osgImg->getWidth()) == getWidth() &&
               static_cast<UInt32>(osgImg->getHeight()) == getHeight();
    }

    //! Encodes all (side, level) surfaces into the block format of \a osgImg.
    bool compressInto(const ImagePtr &osgImg, std::uint32_t sides, std::uint32_t mipMapCount)
    {
        // Compressed sources are loaded without a context, so create one here
        // instead of falling back to a raw copy into the smaller compressed buffer.
        if(!_context)
            _context = std::make_unique<nvtt::Context>(true);

        nvtt::CompressionOptions compressionOptions;
        compressionOptions.setQuality(nvtt::Quality_Normal);
        compressionOptions.setFormat(Surface::convertToNvttFormat(osgImg->getInternalPixelFormat()));

        if(_context->isCudaAccelerationEnabled())
        {
            // One OutputOptions/OutputHandler pair per (side, level). Both
            // vectors must outlive the batch compress() call.
            const std::size_t count = static_cast<std::size_t>(sides) * mipMapCount;
            std::vector<nvtt::OutputOptions> outputOptions(count);
            std::vector<OutputHandler> outputHandlers(count);
            nvtt::BatchList batchList;

            std::size_t idx = 0;
            for (std::uint32_t side = 0; side < sides; ++side)
            {
                for (UInt32 level = 0; level < mipMapCount; ++level, ++idx)
                {
                    nvtt::Surface &surface = _surfaces[side]._surfaces[level];
                    if(surface.gpuData() == nullptr)
                        surface.ToGPU();

                    outputHandlers[idx].setImage(osgImg);
                    outputHandlers[idx].setSide(side);
                    outputHandlers[idx].setMipmap(level);
                    outputOptions[idx].setOutputHandler(&outputHandlers[idx]);
                    batchList.Append(&surface, 0, 0, &outputOptions[idx]);
                }
            }
            return _context->compress(batchList, compressionOptions);
        }

        std::atomic<bool> ok{true};
        tbb::parallel_for(tbb::blocked_range<std::uint32_t>(0, sides), [&](const tbb::blocked_range<std::uint32_t> &range)
        {
            for (std::uint32_t side = range.begin(); side != range.end(); ++side)
            {
                for (UInt32 level = 0; level < mipMapCount; ++level)
                {
                    OutputHandler outputHandler;
                    outputHandler.setImage(osgImg);
                    outputHandler.setSide(side);
                    outputHandler.setMipmap(level);

                    nvtt::OutputOptions outputOptions;
                    outputOptions.setOutputHandler(&outputHandler);

                    if(!_context->compress(_surfaces[side]._surfaces[level], 0, 0, compressionOptions, outputOptions))
                        ok = false;
                }
            }
        });
        return ok;
    }

    /*! Copies the float RGBA planes into an uncompressed image with 3 or 4
        interleaved components of type UInt8, Real16 or Real32. R and B are
        swapped for BGR(A) targets. Returns false for other layouts.
    */
    bool copyInto(const ImagePtr &osgImg, std::uint32_t sides, std::uint32_t mipMapCount)
    {
        const auto dataType = osgImg->getDataType();
        const auto components = static_cast<std::uint32_t>(osgImg->getComponents());

        // The writers always emit R, G and B, so narrower targets would overflow.
        if(components < 3)
            return false;
        if(dataType != Image::OSG_UINT8_IMAGEDATA &&
           dataType != Image::OSG_FLOAT16_IMAGEDATA &&
           dataType != Image::OSG_FLOAT32_IMAGEDATA)
            return false;

        const bool swapRGB = (osgImg->getPixelFormat() == osg::Image::OSG_BGR_PF || osgImg->getPixelFormat() == osg::Image::OSG_BGRA_PF);

        for (std::uint32_t side = 0; side < sides; ++side)
        {
            tbb::parallel_for(tbb::blocked_range<std::uint32_t>(0, mipMapCount), [&](const tbb::blocked_range<std::uint32_t> &levels)
            {
                for (std::uint32_t level = levels.begin(); level != levels.end(); ++level)
                {
                    auto &surface = _surfaces[side]._surfaces[level];
                    if(surface.gpuData() != nullptr)
                        surface.ToCPU();

                    const float *r = surface.channel(0);
                    const float *g = surface.channel(1);
                    const float *b = surface.channel(2);
                    const float *a = surface.channel(3);
                    const std::size_t pixelCount = static_cast<std::size_t>(surface.width()) * surface.height() * surface.depth();

                    if(dataType == Image::OSG_UINT8_IMAGEDATA)
                    {
                        writePixels(reinterpret_cast<UInt8 *>(osgImg->getRawData(level, 0, side)),
                                    r, g, b, a, pixelCount, components, swapRGB, toUnorm8);
                    }
                    else if(dataType == Image::OSG_FLOAT16_IMAGEDATA)
                    {
                        writePixels(reinterpret_cast<Real16 *>(osgImg->getRawData(level, 0, side)),
                                    r, g, b, a, pixelCount, components, swapRGB,
                                    [](float v) { return static_cast<Real16>(v); });
                    }
                    else
                    {
                        writePixels(reinterpret_cast<Real32 *>(osgImg->getRawData(level, 0, side)),
                                    r, g, b, a, pixelCount, components, swapRGB,
                                    [](float v) { return v; });
                    }
                }
            });
        }
        return true;
    }

    /*! Converts a float channel value to 8-bit unorm, rounding to nearest.
        Values <= 0 and NaN map to 0, and values >= 1 map to 255. Casting an
        out-of-range float to UInt8 would be undefined behaviour, and filter
        ringing can produce such values.
    */
    static UInt8 toUnorm8(float v)
    {
        if(!(v > 0.0f))
            return 0;
        if(v >= 1.0f)
            return 255;
        return static_cast<UInt8>(v * 255.0f + 0.5f);
    }

    //! Interleaves the channel planes into \a dst (\a components values per pixel) using \a convert.
    template<typename T, typename Convert>
    static void writePixels(T *dst, const float *r, const float *g, const float *b, const float *a,
                            std::size_t pixelCount, std::uint32_t components, bool swapRGB, Convert convert)
    {
        const float *first = swapRGB ? b : r;
        const float *third = swapRGB ? r : b;
        tbb::parallel_for(tbb::blocked_range<std::size_t>(0, pixelCount), [&](const tbb::blocked_range<std::size_t> &range)
        {
            for (std::size_t j = range.begin(); j != range.end(); ++j)
            {
                T *px = dst + j * components;
                px[0] = convert(first[j]);
                px[1] = convert(g[j]);
                px[2] = convert(third[j]);
                if(components == 4)
                    px[3] = convert(a[j]);
            }
        });
    }

    //! Debug helper: saves each side's base level as "<filename><index>.png".
    void write(const std::string& filename) const
    {
        int count{};
        for(auto &surface : _surfaces)
        {
            std::string fullFilename = filename + std::to_string(count++) + ".png";
            surface.save(fullFilename.c_str(), false, true);
        }
    }

    //! Reorders the six faces from DirectX to OpenGL order and reorients them.
    static void convertCubemapDXToOGL(std::vector<Surface> &surfaces)
    {
        if(surfaces.size() != 6)
            return;

        std::swap(surfaces[0], surfaces[1]);
        std::swap(surfaces[2], surfaces[5]);
        std::swap(surfaces[3], surfaces[5]);
        std::swap(surfaces[3], surfaces[4]);

        surfaces[0].rotate90Right();
        surfaces[0].flipX();

        surfaces[1].rotate90Left();
        surfaces[1].flipX();

        surfaces[4].flipX();
        surfaces[2].flipY();
        surfaces[5].flipY();
        surfaces[3].flipX();
    }

    std::vector<Surface>           _surfaces{};   // index = side
    std::unique_ptr<nvtt::Context> _context{};    // owned; null until needed
};

// --------------------------------------------------------------------------

ImageCompressor::ImageCompressor() :
    _surfaceSet(std::make_unique<SurfaceSet>())
{
}

ImageCompressor::~ImageCompressor() = default; // Required for forward-declared unique_ptr destruction

ImageCompressor::ImageCompressor(ImageCompressor &&) noexcept = default;
ImageCompressor &ImageCompressor::operator=(ImageCompressor &&) noexcept = default;

bool ImageCompressor::setImage(const ImagePtr &image)
{
    return _surfaceSet->init(image);
}

/*! Re-allocates \a image as \a pixelFormat, using the size, mip count and
    side count of the stored surfaces, and fills it via convertTo().

    BC6 formats force half-float channel data; other formats keep the
    image's current data type.
    Returns false if \a image is null, nothing is loaded, or the conversion
    fails.
*/
bool ImageCompressor::getImage(ImagePtr &image, UInt32 pixelFormat) const
{
    if(image == NullFC || _surfaceSet->_surfaces.empty())
        return false;

    UInt32 dataType = image->getDataType();
    if(pixelFormat == Image::OSG_RGB_BC6U_PF || pixelFormat == Image::OSG_RGB_BC6S_PF)
        dataType = Image::OSG_FLOAT16_IMAGEDATA;

    image->set(pixelFormat, _surfaceSet->getWidth(), _surfaceSet->getHeight(), _surfaceSet->getDepth(),
               _surfaceSet->getMipmapCount(), 1, 0.0, nullptr, dataType, true, _surfaceSet->getSideCount());
    return _surfaceSet->convertTo(image);
}

/*! Writes the stored surfaces into \a image without re-allocating it; its
    current format and geometry are used as the target.
    Returns false for a null image or if the conversion fails. */
bool ImageCompressor::getImage(ImagePtr &image) const
{
    if(image == NullFC)
        return false;
    return _surfaceSet->convertTo(image);
}

bool ImageCompressor::getData(std::vector<UInt8>& data, Int32 format, Int32 type) const
{
    if(_surfaceSet->_surfaces.empty())
        return false;

    if(_surfaceSet->_surfaces[0].getMipmapCount() == 0)
        return false;

    bool r = true;
    ImagePtr img = Image::create();
    // first we allocate some memory so we don't get a nullptr on getData()
    data.resize(1024);
    img->setExternalData(data.data());
    img->set(format, _surfaceSet->_surfaces[0].width(), _surfaceSet->_surfaces[0].height(), _surfaceSet->_surfaces[0].depth(), 
            _surfaceSet->_surfaces[0].getMipmapCount(), 1, 0.0f, nullptr, type, false, _surfaceSet->_surfaces.size());
    // now resize it to the correct size
    data.resize(img->getRawSize());
    img->setExternalData(data.data());
    if(!_surfaceSet->convertTo(img))
    {
        FFATAL(("ImageCompressor::convertTo failed!"));
        r = false;
    }
    subRefCP(img);
    return r;
}

/*! Compresses \a image in place and builds a full mip chain.

    Target formats:
    - RGB9E5 and float RGB/BGR become BC6U.
    - 8-bit RGB/BGR becomes BC7 (RGB).
    - 8-bit RGBA/BGRA becomes BC7 (RGBA).

    Returns false, leaving the image untouched, if \a image is null, not
    compressible, already compressed, internal, has fewer than 3 components,
    or has an unsupported format.
*/
bool ImageCompressor::compress(ImagePtr& image)
{
    if(image == NullFC)
        return false;

    if(!image->getIsCompressible() || image->isCompressed() || image->isInternal())
        return false;

    const UInt32 channels = image->getComponents();

    // for now
    if(channels < 3)
        return false;

#if 0
    std::string name = image->getName();
    std::string path;
    if(!name.empty())
    {
        path = "E:/EnvironmentTextures/out/" + name + ".png";
        image->setPixelFormat(osg::Image::OSG_RGBA_PF);
        image->write(path.c_str());
    }
#endif

    UInt32 pf = osg::Image::OSG_INVALID_PF;
    if(image->getPixelFormat() == osg::Image::OSG_RGB9E5_PF)
    {
        pf = osg::Image::OSG_RGB_BC6U_PF;
    }
    else if(image->getPixelFormat() == osg::Image::OSG_RGB_PF || image->getPixelFormat() == osg::Image::OSG_BGR_PF)
    {
        if(image->getDataType() == osg::Image::OSG_FLOAT16_IMAGEDATA ||
            image->getDataType() == osg::Image::OSG_FLOAT32_IMAGEDATA)
        {
            pf = osg::Image::OSG_RGB_BC6U_PF;
        }
        else if(image->getDataType() == osg::Image::OSG_UINT8_IMAGEDATA)
        {
            pf = osg::Image::OSG_RGB_BC7_PF;
        }
    }
    else if(image->getPixelFormat() == osg::Image::OSG_RGBA_PF || image->getPixelFormat() == osg::Image::OSG_BGRA_PF)
    {
        if(image->getDataType() == osg::Image::OSG_UINT8_IMAGEDATA)
        {
            pf = osg::Image::OSG_RGBA_BC7_PF;
        }
    }

    if(pf == osg::Image::OSG_INVALID_PF)
        return false;

    if(!setImage(image))
        return false;

    createMipmaps();

#if 0
    if(!name.empty())
    {
        path = "E:/EnvironmentTextures/out/" + name + "-surface.png";
        _surfaceSet->write(path);
    }
#endif

    return getImage(image, pf);
}

void ImageCompressor::convertCubemapLayoutFromDXToOGL()
{
    _surfaceSet->convertCubemapLayoutFromDXToOGL();
}

void ImageCompressor::createMipmaps()
{
    if(_surfaceSet->_surfaces.empty())
        return;

    for(auto &surface : _surfaceSet->_surfaces)
        surface.createMipmaps();
}

void ImageCompressor::write(const std::string& filename) const
{
    _surfaceSet->write(filename);
}
