//
// Copyright 2018 The ANGLE Project Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
//
// ConvertVertex.comp: vertex buffer conversion.  Implements functionality in copyvertex.inc.
//
// Each thread of the dispatch call fills in one 4-byte element, no matter how many components
// fit in it.  The src data is laid out in the most general form as follows.  Note that component
// size is assumed to divide buffer stride.
//
//    Ns components, each Bs bytes
//         ____^_____
//        /          |
//       +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
//       |C1|C2|..|CN|..|..|..|..|C1|C2|..|CN|..|..|..|..|C1|C2|..|CN| ... Repeated V times
//       +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
//        \__________ __________/
//                   V
//           Ss bytes of stride
//
// The output is the array of components converted to the destination format (each Bd bytes) with
// stride Sd = Nd*Bd (i.e. packed).  The output size is therefore V*Nd*Bd bytes.  The dispatch size
// is accordingly ceil(V*Nd*Bd / 4).
//
// The input is received in 4-byte elements, therefore each element has Es=4/Bs components.
//
// To output exactly one 4-byte element, each thread is responsible for Ed=4/Bd components.
// Therefore, thread t is responsible for component indices [Ed*t, Ed*(t + 1)).
//
// We don't use Bs and Es for A2B10G10R10 and R10G10B10A2 formats since they take 10 or 2 bits per
// component. Variables that are computed using Bs or Es are hardcoded instead.
//
// Component index c is at source offset:
//
//     floor(c / Ns) * Ss + mod(c, Ns) * Bs
//
//   - Flags:
//     * IsAligned: if true, assumes the workgroup size divides the output count, so there is no
//                  need for bound checking.
//     * IsBigEndian
//   - Conversion:
//     * SintToSint: covers byte, short and int types (distinguished by Bs and Bd).
//     * UintToUint: covers ubyte, ushort and uint types (distinguished by Bs and Bd).
//     * SintToFloat: Same types as SintToSint for source (including scaled).  Converts to float.
//     * UintToFloat: Same types as UintToUint for source (including uscaled).  Converts to float.
//     * SnormToFloat: Similar to SintToFloat, but normalized.
//     * UnormToFloat: Similar to UintToFloat, but normalized.
//     * FixedToFloat: 16.16 signed fixed-point to floating point.
//     * FloatToFloat: float.
//     * A2BGR10SintToSint: covers the signed int type of component when format is only A2BGR10.
//     * A2BGR10UintToUint: covers the unsigned int type of component when format is only A2BGR10.
//     * A2BGR10SintToFloat: Same types as A2BGR10SintToSint for source (including scaled).
//                           Converts to float.
//     * A2BGR10UintToFloat: Same types as A2BGR10UintToUint for source (including uscaled).
//                           Converts to float.
//     * A2BGR10SnormToFloat: Similar to SintToFloat, but normalized and only for A2BGR10.
//     * RGB10A2SintToFloat: Same types as RGB10A2SintToSint for source (including scaled).
//                           Converts to float.
//     * RGB10A2UintToFloat: Same types as RGB10A2UintToUint for source (including uscaled).
//                           Converts to float.
//     * RGB10A2SnormToFloat: Similar to SintToFloat, but normalized and only for RGB10A2.
//     * RGB10A2UnormToFloat: Similar to UintToFloat, but normalized and only for RGB10A2.
//     * RGB10X2SintToFloat: Same types as RGB10X2SintToSint for source (including scaled).
//                           Converts to float.
//     * RGB10X2UintToFloat: Same types as RGB10X2UintToUint for source (including uscaled).
//                           Converts to float.
//     * RGB10X2SnormToFloat: Similar to SintToFloat, but normalized and only for RGB10X2.
//     * RGB10X2UnormToFloat: Similar to UintToFloat, but normalized and only for RGB10X2.
//     * HalfFloatToHalfFloat: covers half float type only.
//
// SintToSint, UintToUint and FloatToFloat correspond to CopyNativeVertexData() and
// Copy8SintTo16SintVertexData() in renderer/copyvertex.inc, FixedToFloat corresponds to
// Copy32FixedTo32FVertexData, SintToFloat and UintToFloat correspond to CopyTo32FVertexData with
// normalized=false and SnormToFloat and UnormToFloat correspond to CopyTo32FVertexData with
// normalized=true. A2BGR10SintToSint, A2BGR10UintToUint, A2BGR10SintToFloat, A2BGR10UintToFloat
// and A2BGR10SnormToFloat correspond to CopyXYZ10W2ToXYZW32FVertexData with the proper options.
// RGB10A2SintToFloat, RGB10A2UintToFloat, RGB10A2SnormToFloat, and RGB10A2UnormToFloat correspond
// to CopyW2XYZ10ToXYZW32FVertexData, and RGB10X2SintToFloat, RGB10X2UintToFloat,
// RGB10X2SnormToFloat, and RGB10X2UnormToFloat correspond to CopyXYZ10ToXYZW32FVertexData with
// the proper options. HalfFloatToHalfFloat corresponds to CopyNativeVertexData().

#version 450 core

// Source type
//
// SrcType is the type returned by loadSourceComponent().  Signed integer sources are carried as
// int; unsigned integer sources, and half floats (which are copied as raw bits), as uint.
// Normalized, fixed-point and float sources are already turned into their float value while being
// loaded, so they use float.
#if SintToSint || SintToFloat || A2BGR10SintToSint || A2BGR10SintToFloat || RGB10A2SintToFloat || \
    RGB10X2SintToFloat
#define SrcType int
#elif UintToUint || UintToFloat || A2BGR10UintToUint || A2BGR10UintToFloat || \
    RGB10A2UintToFloat || RGB10X2UintToFloat || HalfFloatToHalfFloat
#define SrcType uint
#elif SnormToFloat || UnormToFloat || FixedToFloat || FloatToFloat || A2BGR10SnormToFloat || \
    RGB10A2SnormToFloat || RGB10A2UnormToFloat || RGB10X2SnormToFloat || RGB10X2UnormToFloat
#define SrcType float
#else
// None of the known conversion flags evaluated to true: refuse to compile rather than guess.
#error "Not all conversions are accounted for"
#endif

// Destination type
//
// DestType is the type of a converted component before it is packed into the output word.
// IsDestFloat is 1 exactly when DestType is float; it is used further down to hardcode Bd and Ed
// and to select how the value is turned into bits.
#if SintToSint || A2BGR10SintToSint
#define DestType int
#define IsDestFloat 0
#elif UintToUint || A2BGR10UintToUint || HalfFloatToHalfFloat
#define DestType uint
#define IsDestFloat 0
#elif SintToFloat || UintToFloat || SnormToFloat || UnormToFloat || FixedToFloat || FloatToFloat || \
    A2BGR10SintToFloat || A2BGR10UintToFloat || A2BGR10SnormToFloat || \
    RGB10A2SintToFloat || RGB10A2UintToFloat || RGB10A2SnormToFloat || RGB10A2UnormToFloat || \
    RGB10X2SintToFloat || RGB10X2UintToFloat || RGB10X2SnormToFloat || RGB10X2UnormToFloat
#define DestType float
#define IsDestFloat 1
#else
// None of the known conversion flags evaluated to true: refuse to compile rather than guess.
#error "Not all conversions are accounted for"
#endif

// Each workgroup has 64 invocations along x.  main() writes one 4-byte destination element per
// invocation.
layout (local_size_x = 64, local_size_y = 1, local_size_z = 1) in;

// Destination buffer, accessed as an array of 4-byte words.  Written by
// storeDestinationComponents().
layout (set = 0, binding = 0) buffer dest
{
    uint destData[];
};

// Source buffer, accessed as an array of 4-byte words.  Read by loadSourceComponent(), which
// extracts the individual components with shifts and masks.
layout (set = 0, binding = 1) buffer src
{
    uint srcData[];
};

// Per-dispatch parameters.  The Ns/Bs/Ss/Es and Nd/Bd/Sd/Ed members use the notation of the
// explanation at the top of this file.
layout (push_constant) uniform PushConstants
{
    // outputs to write (= total number of components / Ed): used for range checking
    uint outputCount;
    // total number of output components: used for range checking
    uint componentCount;
    // source and destination offsets are handled in the shader (instead of binding the buffer with
    // these offsets), as the binding offset requires alignment with
    // minStorageBufferOffsetAlignment, which is impossible to enforce on source, and therefore
    // would limit the usability of the shader.  Note that source is a storage buffer, instead of a
    // uniform buffer, so it wouldn't be affected by the possibly smaller max size of uniform
    // buffers.
    uint srcOffset;
    uint destOffset;

    // Parameters from the above explanation
    uint Ns;       // Number of source components in one vertex attribute
    uint Bs;       // Source component byte size
    uint Ss;       // Source vertex attribute byte stride
    uint Es;       // Precalculated 4/Bs

    uint Nd;       // Number of destination components in one vertex attribute
    uint Bd;       // Destination component byte size
    uint Sd;       // Precalculated Nd*Bd
    uint Ed;       // Precalculated 4/Bd
} params;

// Define shorthands for more readable formulas:
//
// Note: once a shorthand is defined, only the bare name may be used.  Writing e.g. `params.Ns`
// below this point would itself be macro-expanded (to `params.params.Ns`).
#define Ns params.Ns
#define Ss params.Ss
#define Nd params.Nd
#define Sd params.Sd

// With fixed-point and float types, Bs and Bd can only be 4, so they are hardcoded for more
// efficiency.
#if FixedToFloat || FloatToFloat
// 4-byte source components: exactly one component per 4-byte source element.
#define Bs 4
#define Es 1
#else
#define Bs params.Bs
#define Es params.Es
#endif

#if IsDestFloat
// A float destination component occupies a whole 4-byte output element.
#define Bd 4
#define Ed 1
#else
#define Bd params.Bd
#define Ed params.Ed
#endif

// Returns the byte offset, from the start of the source buffer, of component `component` of vertex
// `vertex`: the buffer offset, plus `vertex` strides of Ss bytes, plus `component` components of
// Bs bytes each.
uint getSourceComponentOffset(uint vertex, uint component)
{
    uint vertexStart = params.srcOffset + vertex * Ss;
    uint withinVertex = component * Bs;
    return vertexStart + withinVertex;
}

// Returns the byte offset, from the start of the destination buffer, of component `component` of
// vertex `vertex`: the buffer offset, plus `vertex` strides of Sd bytes, plus `component`
// components of Bd bytes each.
uint getDestinationComponentOffset(uint vertex, uint component)
{
    uint vertexStart = params.destOffset + vertex * Sd;
    uint withinVertex = component * Bd;
    return vertexStart + withinVertex;
}

// Given the byte offset of a component and its size B in bytes, returns the bit shift needed to
// extract the component from (or place it into) the 4-byte word that contains it.
//
// The component size is assumed to divide the stride, so only the offsets listed below occur.
//
// Little endian: shift = (offset % 4) * 8
//
//     B == 1: 0->0, 1->8, 2->16, 3->24
//     B == 2: 0->0, 2->16
//     B == 4: 0->0
//
// Big endian: shift = (4 - B - offset % 4) * 8
//
//     B == 1: 0->24, 1->16, 2->8, 3->0
//     B == 2: 0->16, 2->0
//     B == 4: 0->0
uint getShiftBits(uint offset, uint B)
{
    // Bit position of the addressed byte, counted from the least-significant end of the word.
    uint lowEndShift = (offset & 3u) << 3;

#if IsBigEndian
    // The most-significant bits hold the first components, so count from the other end instead.
    return (4 - B) * 8 - lowEndShift;
#else
    return lowEndShift;
#endif
}

// Loads the source component that feeds destination component `cd` and returns it, already turned
// into its final numeric value, as SrcType.
//
// cd: index of the component in the destination buffer.  It is split into a vertex index
//     (cd / Nd) and a component index within that vertex (cd % Nd).
//
// Returns 0 for components the source does not provide (component >= Ns).  The RGB10X2 conversions
// are the exception: their component 3 is replaced by a constant that converts to 1.
SrcType loadSourceComponent(uint cd)
{
    // cd is component index in the destination buffer
    uint vertex = cd / Nd;
    uint component = cd % Nd;

    // Components missing from the source are filled with 0.  The RGB10X2 conversions are excluded
    // from this early-out because their fourth component is substituted further down instead of
    // being read from the source.
    // NOTE(review): for every other conversion a missing alpha component also becomes 0 here, not
    // 1 -- confirm that this is intended.
#if !(RGB10X2SintToFloat || RGB10X2UintToFloat || RGB10X2SnormToFloat || RGB10X2UnormToFloat)
    if (component >= Ns)
    {
        return 0;
    }
#endif

    // Load the source component
    // The whole 4-byte word containing the component is fetched; the component itself is extracted
    // from it below using shiftBits and valueMask.
    uint offset = getSourceComponentOffset(vertex, component);
    uint block = srcData[offset / 4];
    // A2B10G10R10's components are not byte-aligned, hardcoding values for efficiency.
    // Components 0..2 are 10 bits wide starting at bits 0, 10 and 20; component 3 is 2 bits wide
    // starting at bit 30.
#if A2BGR10SintToSint || A2BGR10UintToUint || A2BGR10SnormToFloat || A2BGR10SintToFloat || \
    A2BGR10UintToFloat
    uint valueBits = component == 3 ? 2 : 10;
    uint shiftBits = 10 * component;
    uint valueMask = component == 3 ? 0x03 : 0x3FF;
#elif RGB10A2SintToFloat || RGB10A2UintToFloat || RGB10A2SnormToFloat || RGB10A2UnormToFloat || \
    RGB10X2SintToFloat || RGB10X2UintToFloat || RGB10X2SnormToFloat || RGB10X2UnormToFloat
    uint valueBits = component == 3 ? 2 : 10;
    // channel order is reversed
    // Components 0, 1 and 2 start at bits 22, 12 and 2; the 2-bit component 3 starts at bit 0.
    uint shiftBits = component == 3 ? 0 : (valueBits * (2 - component) + 2);
    uint valueMask = component == 3 ? 0x03 : 0x3FF;
#else
    // Byte-aligned component of Bs bytes.  The shift of 1 by 32 bits is avoided by special-casing
    // 32-bit components.
    uint shiftBits = getShiftBits(offset, Bs);
    uint valueBits = Bs * 8;
    uint valueMask = valueBits == 32 ? -1 : (1 << valueBits) - 1;
#endif

    // When alpha channel is dummy, fill it with 1.0f
    // The substituted raw value is chosen so that the conversion below yields 1: the 2-bit value 1
    // for the integer and snorm variants, and the 2-bit maximum 3 for the unorm variant.
#if RGB10X2SintToFloat || RGB10X2UintToFloat || RGB10X2SnormToFloat
    uint valueAsUint = component == 3 ? 0x01 : (block >> shiftBits) & valueMask;
#elif RGB10X2UnormToFloat
    uint valueAsUint = component == 3 ? 0x03 : (block >> shiftBits) & valueMask;
#else
    uint valueAsUint = (block >> shiftBits) & valueMask;
#endif

    // Convert to SrcType
#if SintToSint || SintToFloat || A2BGR10SintToSint || A2BGR10SintToFloat || RGB10A2SintToFloat \
    || RGB10X2SintToFloat
    // Signed integer: replicate the component's top bit into the upper bits of the word, then
    // reinterpret the word as int.
    if (valueBits < 32)
    {
        bool isNegative = (valueAsUint & (1 << (valueBits - 1))) != 0;
        // Sign extend
        // Note: if valueBits == 32, then 0xFFFFFFFF << valueBits is undefined,
        // causing sign extension of value below to produce incorrect values.
        uint signExtension = isNegative ? 0xFFFFFFFF << valueBits : 0;
        valueAsUint |= signExtension;
    }
    SrcType value = SrcType(valueAsUint);
#elif UintToUint || UintToFloat || A2BGR10UintToUint || A2BGR10UintToFloat || RGB10A2UintToFloat \
    || RGB10X2UintToFloat || HalfFloatToHalfFloat
    // Unsigned integer (or raw half-float bits): the extracted bits are the value.
    SrcType value = valueAsUint;
#elif SnormToFloat || A2BGR10SnormToFloat || RGB10A2SnormToFloat || RGB10X2SnormToFloat
    // Signed normalized: sign extend as for signed integers, then scale.
    if (valueBits < 32)
    {
        bool isNegative = (valueAsUint & (1 << (valueBits - 1))) != 0;
        uint signExtension = isNegative ? 0xFFFFFFFF << valueBits : 0;
        valueAsUint |= signExtension;
    }
    int valueAsInt = int(valueAsUint);
    // (valueMask >> 1) is the largest positive value of a valueBits-wide signed integer.  The most
    // negative value would scale to slightly below -1, so the result is clamped to -1.
    SrcType value = float(valueAsInt) / (valueMask >> 1);
    value = max(value, float(-1));
#elif UnormToFloat || RGB10A2UnormToFloat || RGB10X2UnormToFloat
    // Unsigned normalized: valueMask is also the largest representable value.
    float positiveMax = valueMask;
    // Scale [0, P] to [0, 1]
    SrcType value = valueAsUint / positiveMax;
#elif FixedToFloat
    // 16.16 fixed point: reinterpret as signed and scale by 2^-16.
    float divisor = 1.0f / 65536.0f;
    SrcType value = int(valueAsUint) * divisor;
#elif FloatToFloat
    // The loaded bits already are an IEEE float.
    SrcType value = uintBitsToFloat(valueAsUint);
#else
#error "Not all conversions are accounted for"
#endif

    return value;
}

// Converts a loaded source value to the destination type.  loadSourceComponent() has already
// produced the final value, so the only thing left is a change of type, which is done by the
// implicit conversion in the initialization below.
DestType convertComponent(SrcType srcValue)
{
    DestType converted = srcValue;
    return converted;
}

// Returns the bits of destination component `cd`, already shifted to the position they occupy in
// the 4-byte output element.  The results of the calls for all components of one element are
// meant to be |ed together and then written to the destination.
//
// cd:    index of the component in the destination buffer.
// value: the converted component value.
uint makeDestinationComponent(uint cd, DestType value)
{
#if SintToSint || UintToUint || A2BGR10SintToSint || A2BGR10UintToUint || HalfFloatToHalfFloat
    uint vertex = cd / Nd;
    uint component = cd % Nd;

    uint offset = getDestinationComponentOffset(vertex, component);
    uint shiftBits = getShiftBits(offset, Bd);

    // Keep only the low Bd bytes of the value.  All arithmetic is done on unsigned operands so
    // that the mask never goes through a signed intermediate; a 32-bit component is special-cased
    // because shifting by 32 bits is undefined.
    uint valueBits = Bd * 8;
    uint valueMask = valueBits == 32 ? 0xFFFFFFFFu : (1u << valueBits) - 1u;
    uint valueAsUint = (uint(value) & valueMask) << shiftBits;

#elif IsDestFloat
    // If the destination is float, it will occupy the whole result.  The bit pattern is taken
    // directly as uint instead of going through int.
    uint valueAsUint = floatBitsToUint(value);

#else
#error "Not all conversions are accounted for"
#endif

    return valueAsUint;
}

// Writes the assembled 4-byte element of the current invocation to the destination buffer.
//
// valueAsUint: the output element, i.e. the | of all its components.
void storeDestinationComponents(uint valueAsUint)
{
    // Note that the destination allocations are always aligned to kMaxVertexFormatAlignment.
    // The byte offset is converted to an index into the array of 4-byte words.
    uint firstDestWord = params.destOffset / 4;
    uint destWord = gl_GlobalInvocationID.x + firstDestWord;

    destData[destWord] = valueAsUint;
}

// Entry point.  Invocation t assembles one 4-byte destination element out of the destination
// components [Ed*t, Ed*(t + 1)) and stores it.
void main()
{
#if !IsAligned
    // Invocations beyond the number of outputs have nothing to write.
    if (gl_GlobalInvocationID.x >= params.outputCount)
    {
        return;
    }
#endif  // !IsAligned

    uint firstComponent = gl_GlobalInvocationID.x * Ed;
    uint packedWord = 0;

    for (uint slot = 0; slot < Ed; ++slot)
    {
        uint cd = firstComponent + slot;
#if !IsAligned
        // The last element may be only partially filled; the rest of its bits stay 0.
        if (cd >= params.componentCount)
        {
            break;
        }
#endif

        DestType converted = convertComponent(loadSourceComponent(cd));
        packedWord |= makeDestinationComponent(cd, converted);
    }

    storeDestinationComponents(packedWord);
}
