reshape.glsl (2426B)
/* See LICENSE for license details. */

/* Reshape: each invocation copies one element from a strided input layout to a
 * strided output layout, converting between element kinds as needed. When
 * Interleave is set, two real inputs (left/right) are merged into the two
 * components of one complex output element.
 *
 * Not defined in this file; presumably supplied by a prepended header or the
 * host (TODO confirm): InputDataKind, OutputDataKind, the DataKind_* values,
 * Interleave, SizeX/SizeY/SizeZ, InputStrideX/Y/Z, OutputStrideX/Y/Z, the
 * addresses left_input_buffer, right_input_buffer and output_buffer, and the
 * s16/f16/f32/u32 type aliases. */

#if InputDataKind == DataKind_Float32Complex
#define Input          Float32Complex
#define InputIsComplex 1
#elif InputDataKind == DataKind_Float32
#define Input          Float32
#define InputIsComplex 0
#elif InputDataKind == DataKind_Float16Complex || InputDataKind == DataKind_Int16Complex
/* NOTE(review): Float16Complex input is read through the s16vec2 view, so each
 * half-float bit pattern is treated as a 16-bit integer. That is a bit-exact
 * copy only when the output kind is Int16Complex; any other output kind
 * converts the integer value numerically. Confirm this is intended. */
#define Input          Int16Complex
#define InputIsComplex 1
#elif InputDataKind == DataKind_Float16 || InputDataKind == DataKind_Int16
/* NOTE(review): Float16 input is likewise read through the s16 view. */
#define Input          Int16
#define InputIsComplex 0
#else
#error unsupported input data kind for Reshape
#endif

#if OutputDataKind == DataKind_Float32Complex
#define Output          Float32Complex
#define OutputKind      f32vec2
#define OutputIsComplex 1
#elif OutputDataKind == DataKind_Float32
#define Output          Float32
#define OutputKind      f32
#define OutputIsComplex 0
#elif OutputDataKind == DataKind_Float16Complex
#define Output          Float16Complex
#define OutputKind      f16vec2
#define OutputIsComplex 1
#elif OutputDataKind == DataKind_Float16
#define Output          Float16
#define OutputKind      f16
#define OutputIsComplex 0
#elif OutputDataKind == DataKind_Int16Complex
#define Output          Int16Complex
#define OutputKind      s16vec2
#define OutputIsComplex 1
#elif OutputDataKind == DataKind_Int16
#define Output          Int16
#define OutputKind      s16
#define OutputIsComplex 0
#else
#error unsupported output data kind for Reshape
#endif

/* Reject configurations that cannot type-check, with a readable message
 * instead of an obscure type error inside main():
 * - Interleave writes one real input value into each component of a complex
 *   output, so it needs a real input and a complex output.
 * - The plain path assigns the input element straight to the output element,
 *   and GLSL has no implicit scalar<->vector conversion, so the complexity of
 *   input and output must match. */
#if Interleave
#if InputIsComplex || !OutputIsComplex
#error Reshape Interleave requires a real input data kind and a complex output data kind
#endif
#elif InputIsComplex != OutputIsComplex
#error Reshape without Interleave requires input and output data kinds of matching complexity
#endif

/* Typed views over device addresses (GL_EXT_buffer_reference). Each block name
 * also serves as the cast used in main(), e.g. Input(left_input_buffer). */
layout(std430, buffer_reference, buffer_reference_align = 8) restrict buffer Int16 {
	s16 x[];
};

layout(std430, buffer_reference, buffer_reference_align = 8) restrict buffer Int16Complex {
	s16vec2 x[];
};

layout(std430, buffer_reference, buffer_reference_align = 8) restrict buffer Float16 {
	f16 x[];
};

layout(std430, buffer_reference, buffer_reference_align = 8) restrict buffer Float16Complex {
	f16vec2 x[];
};

layout(std430, buffer_reference, buffer_reference_align = 8) restrict buffer Float32 {
	f32 x[];
};

layout(std430, buffer_reference, buffer_reference_align = 8) restrict buffer Float32Complex {
	f32vec2 x[];
};

void main(void)
{
	/* Invocations outside the SizeX x SizeY x SizeZ extent do nothing. */
	if (all(lessThan(gl_GlobalInvocationID, uvec3(SizeX, SizeY, SizeZ)))) {
		u32 x = gl_GlobalInvocationID.x;
		u32 y = gl_GlobalInvocationID.y;
		u32 z = gl_GlobalInvocationID.z;

		/* Strides are counted in elements of the respective view (they index x[]). */
		u32 input_index  = InputStrideX  * x + InputStrideY  * y + InputStrideZ  * z;
		u32 output_index = OutputStrideX * x + OutputStrideY * y + OutputStrideZ * z;

		OutputKind out_value = OutputKind(0);

#if Interleave
		/* Left input fills component 0, right input fills component 1. */
		out_value[0] = Input(left_input_buffer).x[input_index];
		out_value[1] = Input(right_input_buffer).x[input_index];
#else
		out_value = Input(left_input_buffer).x[input_index];
#endif

		Output(output_buffer).x[output_index] = out_value;
	}
}