Initial public release.
[OpenCLIPER] / src / kernels / complexElementProd.cl
1 /* Copyright (C) 2018 Federico Simmross Wattenberg,
2  *                    Manuel Rodríguez Cayetano,
3  *                    Javier Royuela del Val,
4  *                    Elena Martín González,
5  *                    Elisa Moya Sáez,
6  *                    Marcos Martín Fernández and
7  *                    Carlos Alberola López
8  *
9  * This file is part of OpenCLIPER.
10  *
11  * OpenCLIPER is free software; you can redistribute it and/or modify
12  * it under the terms of the GNU General Public License as published by
13  * the Free Software Foundation; version 3 of the License.
14  *
15  * OpenCLIPER is distributed in the hope that it will be useful, but
16  * WITHOUT ANY WARRANTY; without even the implied warranty of
17  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
18  * General Public License for more details.
19  *
20  * You should have received a copy of the GNU General Public License
21  * along with OpenCLIPER; If not, see <http://www.gnu.org/licenses/>.
22  *
23  *
24  *  Contact:
25  *
26  *  Federico Simmross Wattenberg
27  *  E.T.S.I. Telecomunicación
28  *  Universidad de Valladolid
29  *  Paseo de Belén 15
30  *  47011 Valladolid, Spain.
31  *  fedsim@tel.uva.es
32  */
33 /*
34  * RCS/CVS version control info
35  * $Id: reduce_kernel.cl,v 1.2 2016/11/02 12:34:19 manrod Exp $
36  * $Revision: 1.2 $
37  * $Date: 2016/11/02 12:34:19 $
38  */
39
40 #include <OpenCLIPER/kernels/hostKernelFunctions.h>
41 //#pragma OPENCL_EXTENSION cl_amd_printf
42 #pragma OPENCL EXTENSION cl_amd_printf : enable
43 //#define KERNEL_DEBUG
44 #define complexMul(p1,p2) (float2) ((p1).s0*(p2).s0-(p1).s1*(p2).s1, (p1).s0*(p2).s1+(p1).s1*(p2).s0)
45
46 #define VECTORDATATYPESIZE 16
47 #define VECTORDATATYPEMACRO(baseType,size) {baseType ## size}
48 #define VECTORDATATYPE float16
49 #define VECTORDATATYPEHALFSIZE (VECTORDATATYPESIZE)/2
50 #define HALFVECTORDATATYPE  float8
51 #define MASKDATATYPE uint16
52 #define HALFMASKDATATYPE uint8
53 #define VLOADN vload16
54 #define VSTOREN vstore16
55
56 //#define DEBUG 1
57 #ifdef DEBUG
58 #define PRINTVECTOR(name, vector, numberOfElements) do {printVector(name, vector, numberOfElements);} while (0)
59 #else
60 #define PRINTVECTOR(name, vector, numberOfElements)
61 #endif
62
63 //#define DEBUGKERNEL 1
64
65 void createMasks(HALFMASKDATATYPE* pOddElementsMaskVector, HALFMASKDATATYPE* pEvenElementsMaskVector,
66                  MASKDATATYPE* pMaskCircularLeft1PosVector, MASKDATATYPE* pInterleavedRealAndImagPartsMaskVector,
67                  VECTORDATATYPE* pConjugatePatternVector) {
68     uint* pOddElementsMask = (uint *) pOddElementsMaskVector;
69     uint* pEvenElementsMask = (uint *) pEvenElementsMaskVector;
70     //HALFMASKDATATYPE oddElementsMask = (HALFMASKDATATYPE) (0, 2, 4, 6, 8, 10, 12, 14);
71     //HALFMASKDATATYPE evenElementsMask = (HALFMASKDATATYPE) (1, 3, 5, 7, 9, 11, 13, 15);
72     for (uint i = 0; i < VECTORDATATYPEHALFSIZE; i++) {
73         pOddElementsMask[i] = i * 2;
74         pEvenElementsMask[i] = (i * 2) + 1;
75     }
76
77     //MASKDATATYPE maskCircularLeft1Pos = (MASKDATATYPE) (1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,0);
78     uint* pMaskCircularLeft1Pos = (uint *) pMaskCircularLeft1PosVector;
79     for (uint i = 0; i < VECTORDATATYPESIZE; i++) {
80         pMaskCircularLeft1Pos[i] = (i+1)%VECTORDATATYPESIZE;
81     }
82
83     //MASKDATATYPE interleavedRealAndImagPartsMask = (MASKDATATYPE) (0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15);
84     uint* pInterleavedRealAndImagPartsMask = (uint *) pInterleavedRealAndImagPartsMaskVector;
85     uint interleavedPartsOffset = VECTORDATATYPEHALFSIZE;
86     for (uint i = 0; i < VECTORDATATYPEHALFSIZE; i ++) {
87         pInterleavedRealAndImagPartsMask[i*2] = i;
88         pInterleavedRealAndImagPartsMask[(i*2)+1] = i + interleavedPartsOffset;
89     }
90     // Pattern for conjugating vector
91     realType* pConjugatePattern = (realType *) pConjugatePatternVector;
92     for (uint i = 0; i < VECTORDATATYPEHALFSIZE; i++) {
93         pConjugatePattern[(i * 2)] = 1;
94         pConjugatePattern[(i * 2) + 1] = -1;
95     }
96 }
97
98 void complexElementWiseProductVector(global realType* pInputBuffer1, global realType* pInputBuffer2,
99                                      global realType* pOutputBuffer, ushort conjugateSecondOperand,
100                                      const HALFMASKDATATYPE* pOddElementsMaskVector,
101                                      const HALFMASKDATATYPE* pEvenElementsMaskVector,
102                                      const MASKDATATYPE* pCircularLeft1PosMaskVector,
103                                      const MASKDATATYPE* pInterleavedRealAndImagPartsMaskVector,
104                                      const VECTORDATATYPE* pConjugatePatternVector) {
105     uint offsetInNumberOfElements = 0;
106     VECTORDATATYPE op1;
107     VECTORDATATYPE op2;
108     VECTORDATATYPE res;
109     //printf("Vector data type: %s\n", VECTORDATATYPEMACRO(realType,VECTORDATATYPESIZE);
110     //op1 = (VECTORDATATYPE) {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}; //VLOADN(offsetInNumberOfElements, pInputBuffer1);
111     op1 = VLOADN(offsetInNumberOfElements, pInputBuffer1);
112     //op2 = (VECTORDATATYPE) {-1, -2, -3, -4, -5, -6, -7, -8, -9, -10, -11, -12, -13, -14, -15, -16};//VLOADN(offsetInNumberOfElements, pInputBuffer2);
113     op2 = VLOADN(offsetInNumberOfElements, pInputBuffer2);
114     PRINTVECTOR("op1", (float *) (&op1), VECTORDATATYPESIZE);
115     PRINTVECTOR("op2", (float *) (&op2), VECTORDATATYPESIZE);
116
117     if(conjugateSecondOperand != 0) {
118         // operations between vectors are also valid between scalar and vector (scalar is widened to the vector number of elements)
119         /*
120         ulong conjugateVectorBits;
121         conjugateVectorBits = 0x80000000;
122         op2 = op2 ^ conjugateVectorBits;
123         */
124         op2 = op2 * (*pConjugatePatternVector);
125         PRINTVECTOR("newop2", (float *) (&op2), VECTORDATATYPESIZE);
126     }
127     /// Vector with product of real parts of 2 input vectors and product of imaginary parts of 2 input vectors interleaved
128     VECTORDATATYPE realAndImagPartsProductVector = op1 * op2;
129     PRINTVECTOR("realAndImagPartsProductVector",(float *) (&realAndImagPartsProductVector), VECTORDATATYPESIZE);
130     /// Vector with product of real parts of 2 input vectors
131     HALFVECTORDATATYPE realPartsProductVector = shuffle (realAndImagPartsProductVector, *pOddElementsMaskVector); // get only odd elements
132     /// Vector with product of imaginary parts of 2 input vectors
133     HALFVECTORDATATYPE imagPartsProductVector = shuffle (realAndImagPartsProductVector, *pEvenElementsMaskVector); // get only even elements
134     PRINTVECTOR("realPartsProductVector", (float *) &realPartsProductVector, VECTORDATATYPEHALFSIZE);
135     PRINTVECTOR("imagPartsProductVector", (float *) &imagPartsProductVector, VECTORDATATYPEHALFSIZE);
136     /// Vector with real parts of result vector
137     HALFVECTORDATATYPE realPartsResultVector = realPartsProductVector - imagPartsProductVector; // parallel element-wise substraction
138     PRINTVECTOR("realPartsResultVector", (float *) &realPartsResultVector, VECTORDATATYPEHALFSIZE);
139
140     /// Second input vector with elements shift by 1 position to the left (circular shift)
141     VECTORDATATYPE op2CircLeftShift1 = shuffle(op2, *pCircularLeft1PosMaskVector); // circular left shift by 1 position
142     /// Vector with product of real parts of first input vector multiplied by imaginary parts of second input vector
143     VECTORDATATYPE realXImagPartsVector = op1 * op2CircLeftShift1;
144     /// First input vector with elements shift by 1 position to the left (circular shift)
145     VECTORDATATYPE op1CircLeftShift1 = shuffle(op1, *pCircularLeft1PosMaskVector); // circular left shift by 1 position
146     /// Vector with product of imaginary parts of first input vector multiplied by real parts of second input vector
147     VECTORDATATYPE imagXRealPartsVector = op1CircLeftShift1 * op2;
148     /// Vector with imaginary parts of result vector in odd positions (even positions are not valid)
149     VECTORDATATYPE imagPartsOddElementsVector = realXImagPartsVector + imagXRealPartsVector; // (parallel element-wise addition)
150     PRINTVECTOR("imagPartsOddElementsVector", (float *) &imagPartsOddElementsVector, VECTORDATATYPESIZE);
151     /// Vector with imaginary parts of result vector
152     HALFVECTORDATATYPE imagPartsResultVector = shuffle(imagPartsOddElementsVector,*pOddElementsMaskVector); // get only odd elements
153     /// Result vector build by interleaving elements of vector with real parts and vector with imaginary parts of the operation result
154     res = shuffle2(realPartsResultVector, imagPartsResultVector, *pInterleavedRealAndImagPartsMaskVector);
155     PRINTVECTOR("result", (float *) &res, VECTORDATATYPESIZE);
156     VSTOREN(res, offsetInNumberOfElements, pOutputBuffer);
157     offsetInNumberOfElements = offsetInNumberOfElements + get_global_size(0);
158 #ifdef DEBUG
159     printf("new offset in number of elements: %d\n", offsetInNumberOfElements);
160 #endif
161 }
162
163 __kernel void complexElementProd_kernel(__global realType* pInputBuffer1, __global realType* pInputBuffer2,
164                                                              __global realType* pOutputBuffer, ushort conjugateSecondOperand,
165                                                              __global uint* inputDims, __global uint* sensitivityMapsDims, __global uint* outputDims,
166                                                              __global uint* inputStrides, __global uint* sensitivityMapsStrides,
167                                                              __global uint* outputStrides)  {
168     /// Mask for extracting odd elements from vector
169     HALFMASKDATATYPE oddElementsMaskVector;
170     /// Mask for extracting event elements from vector
171     HALFMASKDATATYPE evenElementsMaskVector;
172     /// Mask for circular rotating 1 pos to the left a vector
173     MASKDATATYPE circularLeft1PosMaskVector;
174     /**
175       * Mask for building a vector with real and imaginary parts of complex numbers interleaved from a vector of real parts and
176       * a vector of imaginary parts
177       */
178     MASKDATATYPE interleavedRealAndImagPartsMaskVector;
179     /** Mask for conjugating a vector */
180     VECTORDATATYPE conjugatePatternVector;
181     createMasks(&oddElementsMaskVector, &evenElementsMaskVector, &circularLeft1PosMaskVector,
182                 &interleavedRealAndImagPartsMaskVector, &conjugatePatternVector);
183     PRINTF(("NSD: %d\tAllsizesEqual: %d\tNCoils: %d\tNTD: %d\n", inputDims[NumSpatialDimsPos], inputDims[AllSizesEqualPos], inputDims[NumCoilsPos],
184            inputDims[NumTemporalDimsPos]));
185     uint inputOffsetFrameId, outputOffsetFrameId, inputOffsetCoilId, sensitivityMapsOffsetCoilId, outputOffsetCoilId,
186         inputOffsetFrameAndCoilId, outputOffsetFrameAndCoilId, inputIndexRealPartElement, sensitivityMapsIndexRealPartElement,
187         outputIndexRealPartElement;
188
189     uint frameDimIndex = 0;
190     uint frameId, coilId, elementIndex1D;
191     uint numCoils, numFrames;
192     //numCoils = sensitivityMapsDims.numCoils;
193     numCoils = getNumCoils(sensitivityMapsDims);
194     numFrames = getTemporalDimSize(inputDims, frameDimIndex);
195     PRINTF(("Starting complexElementWiseProduct_kernel...\n"));
196     frameId = get_global_id(2);
197     //PRINTF(("frameId: %d\nnumFrames: %d\n", frameId, inputDims.numFrames));
198     while (frameId < numFrames) {
199         PRINTF(("frameId: %d\n", frameId));
200         inputOffsetFrameId = frameId * getTemporalDimStride(inputDims, inputStrides, frameDimIndex, 0);
201         outputOffsetFrameId = frameId * getTemporalDimStride(outputDims, outputStrides, frameDimIndex, 0);
202         coilId = get_global_id(1);
203         PRINTF(("coilId: %d\nnumCoils: %d\n", coilId, numCoils));
204         while (coilId < numCoils) {
205             PRINTF(("coilId: %d\n", coilId));
206             if (getNumCoils(inputDims) == 0) { /* input data is XData */
207                 inputOffsetCoilId = 0;
208             } else { /* input data is KData */
209                 inputOffsetCoilId = coilId * getCoilStride(inputDims, inputStrides, 0);
210             }
211             sensitivityMapsOffsetCoilId = coilId * getCoilStride(sensitivityMapsDims, sensitivityMapsStrides, 0);
212             if (getNumCoils(outputDims) == 0) { /* output data is XData */
213                 outputOffsetCoilId = 0;
214             } else { /* output data is KData */
215                 outputOffsetCoilId = coilId * getCoilStride(outputDims, outputStrides, 0);
216             }
217             inputOffsetFrameAndCoilId = inputOffsetFrameId + inputOffsetCoilId;
218             outputOffsetFrameAndCoilId = outputOffsetFrameId + outputOffsetCoilId;
219             elementIndex1D = get_global_id(0) * VECTORDATATYPESIZE;
220             /*
221             PRINTF(("elementIndex1D: %d\nelementIndex1DMaxValue: %d\n", elementIndex1D,
222                    inputDims.width * inputDims.height * inputDims.depth * 2));
223             */
224             while (elementIndex1D < (getNDArrayTotalSize(inputDims, 0) * 2)) {
225                 //PRINTF(("frameId: %d\ncoilId: %d\nelementIndex1D:%d\n", frameId, coilId, elementIndex1D));
226                 inputIndexRealPartElement = inputOffsetFrameAndCoilId + elementIndex1D;
227                 sensitivityMapsIndexRealPartElement = sensitivityMapsOffsetCoilId + elementIndex1D;
228                 outputIndexRealPartElement = outputOffsetFrameAndCoilId + elementIndex1D;
229                 //PRINTF(("inputBuffer[inputIndexRealPartElement: %d]: %f\n", inputIndexRealPartElement, pInputBuffer1[inputIndexRealPartElement]));
230                 /*PRINTF(("sensitivityMap[sensitivityMapsIndexRealPartElement: %d]: %f\n", sensitivityMapsIndexRealPartElement,
231                        pInputBuffer2[sensitivityMapsIndexRealPartElement]));*/
232                 complexElementWiseProductVector(&(pInputBuffer1[inputIndexRealPartElement]),
233                                                 &(pInputBuffer2[sensitivityMapsIndexRealPartElement]),
234                                                 &(pOutputBuffer[outputIndexRealPartElement]), conjugateSecondOperand,
235                                                 &oddElementsMaskVector, &evenElementsMaskVector, &circularLeft1PosMaskVector,
236                                                 &interleavedRealAndImagPartsMaskVector, &conjugatePatternVector);
237                 //PRINTF(("outputBuffer[outputIndexRealPartElement: %d]: %f\n", outputIndexRealPartElement, pOutputBuffer[outputIndexRealPartElement]));
238                 elementIndex1D += get_global_size(0) * VECTORDATATYPESIZE;
239             }
240             //printVector("Result: ", (float *)pOutputBuffer, numColumns * 2);
241             coilId += get_global_size(1);
242         }
243         frameId += get_global_size(2);
244     }
245     PRINTF(("done.\n"));
246 }