ViennaCL - The Vienna Computing Library
1.5.2
|
00001 #ifndef VIENNACL_LINALG_OPENCL_KERNELS_SVD_HPP 00002 #define VIENNACL_LINALG_OPENCL_KERNELS_SVD_HPP 00003 00004 #include "viennacl/tools/tools.hpp" 00005 #include "viennacl/ocl/kernel.hpp" 00006 #include "viennacl/ocl/platform.hpp" 00007 #include "viennacl/ocl/utils.hpp" 00008 00011 namespace viennacl 00012 { 00013 namespace linalg 00014 { 00015 namespace opencl 00016 { 00017 namespace kernels 00018 { 00019 template <typename StringType> 00020 void generate_svd_bidiag_pack(StringType & source, std::string const & numeric_string) 00021 { 00022 source.append("__kernel void bidiag_pack(__global "); source.append(numeric_string); source.append("* A, \n"); 00023 source.append(" __global "); source.append(numeric_string); source.append("* D, \n"); 00024 source.append(" __global "); source.append(numeric_string); source.append("* S, \n"); 00025 source.append(" uint size1, \n"); 00026 source.append(" uint size2, \n"); 00027 source.append(" uint stride \n"); 00028 source.append(") { \n"); 00029 source.append(" uint size = min(size1, size2); \n"); 00030 00031 source.append(" if(get_global_id(0) == 0) \n"); 00032 source.append(" S[0] = 0; \n"); 00033 00034 source.append(" for(uint i = get_global_id(0); i < size ; i += get_global_size(0)) { \n"); 00035 source.append(" D[i] = A[i*stride + i]; \n"); 00036 source.append(" S[i + 1] = (i + 1 < size2) ? A[i*stride + (i + 1)] : 0; \n"); 00037 source.append(" } \n"); 00038 source.append("} \n"); 00039 } 00040 00041 template <typename StringType> 00042 void generate_svd_col_reduce_lcl_array(StringType & source, std::string const & numeric_string) 00043 { 00044 // calculates a sum of local array elements 00045 source.append("void col_reduce_lcl_array(__local "); source.append(numeric_string); source.append("* sums, uint lcl_id, uint lcl_sz) { \n"); 00046 source.append(" uint step = lcl_sz >> 1; \n"); 00047 00048 source.append(" while(step > 0) { \n"); 00049 source.append(" if(lcl_id < step) { \n"); 00050 source.append(" sums[lcl_id] += sums[lcl_id + step]; \n"); 00051 source.append(" } \n"); 00052 source.append(" step >>= 1; \n"); 00053 source.append(" barrier(CLK_LOCAL_MEM_FENCE); \n"); 00054 source.append(" } \n"); 00055 source.append("} \n"); 00056 } 00057 00058 template <typename StringType> 00059 void generate_svd_copy_col(StringType & source, std::string const & numeric_string) 00060 { 00061 // probably, this is a ugly way 00062 source.append("__kernel void copy_col(__global "); source.append(numeric_string); source.append("* A, \n"); 00063 source.append(" __global "); source.append(numeric_string); source.append("* V, \n"); 00064 source.append(" uint row_start, \n"); 00065 source.append(" uint col_start, \n"); 00066 source.append(" uint size, \n"); 00067 source.append(" uint stride \n"); 00068 source.append(" ) { \n"); 00069 source.append(" uint glb_id = get_global_id(0); \n"); 00070 source.append(" uint glb_sz = get_global_size(0); \n"); 00071 00072 source.append(" for(uint i = row_start + glb_id; i < size; i += glb_sz) { \n"); 00073 source.append(" V[i - row_start] = A[i * stride + col_start]; \n"); 00074 source.append(" } \n"); 00075 source.append("} \n"); 00076 } 00077 00078 template <typename StringType> 00079 void generate_svd_copy_row(StringType & source, std::string const & numeric_string) 00080 { 00081 // probably, this is too 00082 source.append("__kernel void copy_row(__global "); source.append(numeric_string); source.append("* A, \n"); 00083 source.append(" __global "); source.append(numeric_string); source.append("* V, \n"); 00084 source.append(" uint row_start, \n"); 00085 source.append(" uint col_start, \n"); 00086 source.append(" uint size, \n"); 00087 source.append(" uint stride \n"); 00088 source.append(" ) { \n"); 00089 source.append(" uint glb_id = get_global_id(0); \n"); 00090 source.append(" uint glb_sz = get_global_size(0); \n"); 00091 00092 source.append(" for(uint i = col_start + glb_id; i < size; i += glb_sz) { \n"); 00093 source.append(" V[i - col_start] = A[row_start * stride + i]; \n"); 00094 source.append(" } \n"); 00095 source.append("} \n"); 00096 } 00097 00098 template <typename StringType> 00099 void generate_svd_final_iter_update(StringType & source, std::string const & numeric_string) 00100 { 00101 source.append("__kernel void final_iter_update(__global "); source.append(numeric_string); source.append("* A, \n"); 00102 source.append(" uint stride, \n"); 00103 source.append(" uint n, \n"); 00104 source.append(" uint last_n, \n"); 00105 source.append(" "); source.append(numeric_string); source.append(" q, \n"); 00106 source.append(" "); source.append(numeric_string); source.append(" p \n"); 00107 source.append(" ) \n"); 00108 source.append("{ \n"); 00109 source.append(" uint glb_id = get_global_id(0); \n"); 00110 source.append(" uint glb_sz = get_global_size(0); \n"); 00111 00112 source.append(" for (uint px = glb_id; px < last_n; px += glb_sz) \n"); 00113 source.append(" { \n"); 00114 source.append(" "); source.append(numeric_string); source.append(" v_in = A[n * stride + px]; \n"); 00115 source.append(" "); source.append(numeric_string); source.append(" z = A[(n - 1) * stride + px]; \n"); 00116 source.append(" A[(n - 1) * stride + px] = q * z + p * v_in; \n"); 00117 source.append(" A[n * stride + px] = q * v_in - p * z; \n"); 00118 source.append(" } \n"); 00119 source.append("} \n"); 00120 } 00121 00122 template <typename StringType> 00123 void generate_svd_givens_next(StringType & source, std::string const & numeric_string) 00124 { 00125 source.append("__kernel void givens_next(__global "); source.append(numeric_string); source.append("* matr, \n"); 00126 source.append(" __global "); source.append(numeric_string); source.append("* cs, \n"); 00127 source.append(" __global "); source.append(numeric_string); source.append("* ss, \n"); 00128 source.append(" uint size, \n"); 00129 source.append(" uint stride, \n"); 00130 source.append(" uint start_i, \n"); 00131 source.append(" uint end_i \n"); 00132 source.append(" ) \n"); 00133 source.append("{ \n"); 00134 source.append(" uint glb_id = get_global_id(0); \n"); 00135 source.append(" uint glb_sz = get_global_size(0); \n"); 00136 00137 source.append(" uint lcl_id = get_local_id(0); \n"); 00138 source.append(" uint lcl_sz = get_local_size(0); \n"); 00139 00140 source.append(" uint j = glb_id; \n"); 00141 00142 source.append(" __local "); source.append(numeric_string); source.append(" cs_lcl[256]; \n"); 00143 source.append(" __local "); source.append(numeric_string); source.append(" ss_lcl[256]; \n"); 00144 00145 source.append(" "); source.append(numeric_string); source.append(" x = (j < size) ? matr[(end_i + 1) * stride + j] : 0; \n"); 00146 00147 source.append(" uint elems_num = end_i - start_i + 1; \n"); 00148 source.append(" uint block_num = (elems_num + lcl_sz - 1) / lcl_sz; \n"); 00149 00150 source.append(" for(uint block_id = 0; block_id < block_num; block_id++) \n"); 00151 source.append(" { \n"); 00152 source.append(" uint to = min(elems_num - block_id * lcl_sz, lcl_sz); \n"); 00153 00154 source.append(" if(lcl_id < to) \n"); 00155 source.append(" { \n"); 00156 source.append(" cs_lcl[lcl_id] = cs[end_i - (lcl_id + block_id * lcl_sz)]; \n"); 00157 source.append(" ss_lcl[lcl_id] = ss[end_i - (lcl_id + block_id * lcl_sz)]; \n"); 00158 source.append(" } \n"); 00159 00160 source.append(" barrier(CLK_LOCAL_MEM_FENCE); \n"); 00161 00162 source.append(" if(j < size) \n"); 00163 source.append(" { \n"); 00164 source.append(" for(uint ind = 0; ind < to; ind++) \n"); 00165 source.append(" { \n"); 00166 source.append(" uint i = end_i - (ind + block_id * lcl_sz); \n"); 00167 00168 source.append(" "); source.append(numeric_string); source.append(" z = matr[i * stride + j]; \n"); 00169 00170 source.append(" "); source.append(numeric_string); source.append(" cs_val = cs_lcl[ind]; \n"); 00171 source.append(" "); source.append(numeric_string); source.append(" ss_val = ss_lcl[ind]; \n"); 00172 00173 source.append(" matr[(i + 1) * stride + j] = x * cs_val + z * ss_val; \n"); 00174 source.append(" x = -x * ss_val + z * cs_val; \n"); 00175 source.append(" } \n"); 00176 source.append(" } \n"); 00177 source.append(" barrier(CLK_LOCAL_MEM_FENCE); \n"); 00178 source.append(" } \n"); 00179 source.append(" if(j < size) \n"); 00180 source.append(" matr[(start_i) * stride + j] = x; \n"); 00181 source.append("} \n"); 00182 } 00183 00184 template <typename StringType> 00185 void generate_svd_givens_prev(StringType & source, std::string const & numeric_string) 00186 { 00187 source.append("__kernel void givens_prev(__global "); source.append(numeric_string); source.append("* matr, \n"); 00188 source.append(" __global "); source.append(numeric_string); source.append("* cs, \n"); 00189 source.append(" __global "); source.append(numeric_string); source.append("* ss, \n"); 00190 source.append(" uint size, \n"); 00191 source.append(" uint stride, \n"); 00192 source.append(" uint start_i, \n"); 00193 source.append(" uint end_i \n"); 00194 source.append(" ) \n"); 00195 source.append("{ \n"); 00196 source.append(" uint glb_id = get_global_id(0); \n"); 00197 source.append(" uint glb_sz = get_global_size(0); \n"); 00198 00199 source.append(" uint lcl_id = get_local_id(0); \n"); 00200 source.append(" uint lcl_sz = get_local_size(0); \n"); 00201 00202 source.append(" uint j = glb_id; \n"); 00203 00204 source.append(" __local "); source.append(numeric_string); source.append(" cs_lcl[256]; \n"); 00205 source.append(" __local "); source.append(numeric_string); source.append(" ss_lcl[256]; \n"); 00206 00207 source.append(" "); source.append(numeric_string); source.append(" x = (j < size) ? matr[(start_i - 1) * stride + j] : 0; \n"); 00208 00209 source.append(" uint elems_num = end_i - start_i; \n"); 00210 source.append(" uint block_num = (elems_num + lcl_sz - 1) / lcl_sz; \n"); 00211 00212 source.append(" for(uint block_id = 0; block_id < block_num; block_id++) \n"); 00213 source.append(" { \n"); 00214 source.append(" uint to = min(elems_num - block_id * lcl_sz, lcl_sz); \n"); 00215 00216 source.append(" if(lcl_id < to) \n"); 00217 source.append(" { \n"); 00218 source.append(" cs_lcl[lcl_id] = cs[lcl_id + start_i + block_id * lcl_sz]; \n"); 00219 source.append(" ss_lcl[lcl_id] = ss[lcl_id + start_i + block_id * lcl_sz]; \n"); 00220 source.append(" } \n"); 00221 00222 source.append(" barrier(CLK_LOCAL_MEM_FENCE); \n"); 00223 00224 source.append(" if(j < size) \n"); 00225 source.append(" { \n"); 00226 source.append(" for(uint ind = 0; ind < to; ind++) \n"); 00227 source.append(" { \n"); 00228 source.append(" uint i = ind + start_i + block_id * lcl_sz; \n"); 00229 00230 source.append(" "); source.append(numeric_string); source.append(" z = matr[i * stride + j]; \n"); 00231 00232 source.append(" "); source.append(numeric_string); source.append(" cs_val = cs_lcl[ind];//cs[i]; \n"); 00233 source.append(" "); source.append(numeric_string); source.append(" ss_val = ss_lcl[ind];//ss[i]; \n"); 00234 00235 source.append(" matr[(i - 1) * stride + j] = x * cs_val + z * ss_val; \n"); 00236 source.append(" x = -x * ss_val + z * cs_val; \n"); 00237 source.append(" } \n"); 00238 source.append(" } \n"); 00239 source.append(" barrier(CLK_LOCAL_MEM_FENCE); \n"); 00240 source.append(" } \n"); 00241 source.append(" if(j < size) \n"); 00242 source.append(" matr[(end_i - 1) * stride + j] = x; \n"); 00243 source.append("} \n"); 00244 } 00245 00246 template <typename StringType> 00247 void generate_svd_house_update_A_left(StringType & source, std::string const & numeric_string) 00248 { 00249 source.append("__kernel void house_update_A_left( \n"); 00250 source.append(" __global "); source.append(numeric_string); source.append("* A, \n"); 00251 source.append(" __constant "); source.append(numeric_string); source.append("* V, \n"); //householder vector 00252 source.append(" uint row_start, \n"); 00253 source.append(" uint col_start, \n"); 00254 source.append(" uint size1, \n"); 00255 source.append(" uint size2, \n"); 00256 source.append(" uint stride, \n"); 00257 source.append(" __local "); source.append(numeric_string); source.append("* sums \n"); 00258 source.append(" ) { \n"); 00259 source.append(" uint glb_id = get_global_id(0); \n"); 00260 source.append(" uint glb_sz = get_global_size(0); \n"); 00261 00262 source.append(" uint grp_id = get_group_id(0); \n"); 00263 source.append(" uint grp_nm = get_num_groups(0); \n"); 00264 00265 source.append(" uint lcl_id = get_local_id(0); \n"); 00266 source.append(" uint lcl_sz = get_local_size(0); \n"); 00267 00268 source.append(" "); source.append(numeric_string); source.append(" ss = 0; \n"); 00269 00270 // doing it in slightly different way to avoid cache misses 00271 source.append(" for(uint i = glb_id + col_start; i < size2; i += glb_sz) { \n"); 00272 source.append(" ss = 0; \n"); 00273 source.append(" for(uint j = row_start; j < size1; j++) ss = ss + (V[j] * A[j * stride + i]); \n"); 00274 00275 source.append(" for(uint j = row_start; j < size1; j++) \n"); 00276 source.append(" A[j * stride + i] = A[j * stride + i] - (2 * V[j] * ss); \n"); 00277 source.append(" } \n"); 00278 source.append("} \n"); 00279 } 00280 00281 template <typename StringType> 00282 void generate_svd_house_update_A_right(StringType & source, std::string const & numeric_string) 00283 { 00284 00285 source.append("__kernel void house_update_A_right( \n"); 00286 source.append(" __global "); source.append(numeric_string); source.append("* A, \n"); 00287 source.append(" __global "); source.append(numeric_string); source.append("* V, \n"); // householder vector 00288 source.append(" uint row_start, \n"); 00289 source.append(" uint col_start, \n"); 00290 source.append(" uint size1, \n"); 00291 source.append(" uint size2, \n"); 00292 source.append(" uint stride, \n"); 00293 source.append(" __local "); source.append(numeric_string); source.append("* sums \n"); 00294 source.append(" ) { \n"); 00295 00296 source.append(" uint glb_id = get_global_id(0); \n"); 00297 00298 source.append(" uint grp_id = get_group_id(0); \n"); 00299 source.append(" uint grp_nm = get_num_groups(0); \n"); 00300 00301 source.append(" uint lcl_id = get_local_id(0); \n"); 00302 source.append(" uint lcl_sz = get_local_size(0); \n"); 00303 00304 source.append(" "); source.append(numeric_string); source.append(" ss = 0; \n"); 00305 00306 // update of A matrix 00307 source.append(" for(uint i = grp_id + row_start; i < size1; i += grp_nm) { \n"); 00308 source.append(" ss = 0; \n"); 00309 00310 source.append(" for(uint j = lcl_id; j < size2; j += lcl_sz) ss = ss + (V[j] * A[i * stride + j]); \n"); 00311 source.append(" sums[lcl_id] = ss; \n"); 00312 00313 source.append(" barrier(CLK_LOCAL_MEM_FENCE); \n"); 00314 source.append(" col_reduce_lcl_array(sums, lcl_id, lcl_sz); \n"); 00315 source.append(" barrier(CLK_LOCAL_MEM_FENCE); \n"); 00316 00317 source.append(" "); source.append(numeric_string); source.append(" sum_Av = sums[0]; \n"); 00318 00319 source.append(" for(uint j = lcl_id; j < size2; j += lcl_sz) \n"); 00320 source.append(" A[i * stride + j] = A[i * stride + j] - (2 * V[j] * sum_Av); \n"); 00321 source.append(" } \n"); 00322 source.append("} \n"); 00323 00324 } 00325 00326 template <typename StringType> 00327 void generate_svd_house_update_QL(StringType & source, std::string const & numeric_string) 00328 { 00329 source.append("__kernel void house_update_QL( \n"); 00330 source.append(" __global "); source.append(numeric_string); source.append("* QL, \n"); 00331 source.append(" __constant "); source.append(numeric_string); source.append("* V, \n"); //householder vector 00332 source.append(" uint size1, \n"); 00333 source.append(" uint size2, \n"); 00334 source.append(" uint strideQ, \n"); 00335 source.append(" __local "); source.append(numeric_string); source.append("* sums \n"); 00336 source.append(" ) { \n"); 00337 source.append(" uint glb_id = get_global_id(0); \n"); 00338 source.append(" uint glb_sz = get_global_size(0); \n"); 00339 00340 source.append(" uint grp_id = get_group_id(0); \n"); 00341 source.append(" uint grp_nm = get_num_groups(0); \n"); 00342 00343 source.append(" uint lcl_id = get_local_id(0); \n"); 00344 source.append(" uint lcl_sz = get_local_size(0); \n"); 00345 00346 source.append(" "); source.append(numeric_string); source.append(" ss = 0; \n"); 00347 // update of left matrix 00348 source.append(" for(uint i = grp_id; i < size1; i += grp_nm) { \n"); 00349 source.append(" ss = 0; \n"); 00350 source.append(" for(uint j = lcl_id; j < size1; j += lcl_sz) ss = ss + (V[j] * QL[i * strideQ + j]); \n"); 00351 source.append(" sums[lcl_id] = ss; \n"); 00352 00353 source.append(" barrier(CLK_LOCAL_MEM_FENCE); \n"); 00354 source.append(" col_reduce_lcl_array(sums, lcl_id, lcl_sz); \n"); 00355 source.append(" barrier(CLK_LOCAL_MEM_FENCE); \n"); 00356 00357 source.append(" "); source.append(numeric_string); source.append(" sum_Qv = sums[0]; \n"); 00358 00359 source.append(" for(uint j = lcl_id; j < size1; j += lcl_sz) \n"); 00360 source.append(" QL[i * strideQ + j] = QL[i * strideQ + j] - (2 * V[j] * sum_Qv); \n"); 00361 source.append(" } \n"); 00362 source.append("} \n"); 00363 00364 } 00365 00366 template <typename StringType> 00367 void generate_svd_house_update_QR(StringType & source, std::string const & numeric_string) 00368 { 00369 source.append("__kernel void house_update_QR( \n"); 00370 source.append(" __global "); source.append(numeric_string); source.append("* QR, \n"); 00371 source.append(" __global "); source.append(numeric_string); source.append("* V, \n"); // householder vector 00372 source.append(" uint size1, \n"); 00373 source.append(" uint size2, \n"); 00374 source.append(" uint strideQ, \n"); 00375 source.append(" __local "); source.append(numeric_string); source.append("* sums \n"); 00376 source.append(" ) { \n"); 00377 00378 source.append(" uint glb_id = get_global_id(0); \n"); 00379 00380 source.append(" uint grp_id = get_group_id(0); \n"); 00381 source.append(" uint grp_nm = get_num_groups(0); \n"); 00382 00383 source.append(" uint lcl_id = get_local_id(0); \n"); 00384 source.append(" uint lcl_sz = get_local_size(0); \n"); 00385 00386 source.append(" "); source.append(numeric_string); source.append(" ss = 0; \n"); 00387 00388 // update of QR matrix 00389 // Actually, we are calculating a transpose of right matrix. This allows to avoid cache 00390 // misses. 00391 source.append(" for(uint i = grp_id; i < size2; i += grp_nm) { \n"); 00392 source.append(" ss = 0; \n"); 00393 source.append(" for(uint j = lcl_id; j < size2; j += lcl_sz) ss = ss + (V[j] * QR[i * strideQ + j]); \n"); 00394 source.append(" sums[lcl_id] = ss; \n"); 00395 00396 source.append(" barrier(CLK_LOCAL_MEM_FENCE); \n"); 00397 source.append(" col_reduce_lcl_array(sums, lcl_id, lcl_sz); \n"); 00398 source.append(" barrier(CLK_LOCAL_MEM_FENCE); \n"); 00399 00400 source.append(" "); source.append(numeric_string); source.append(" sum_Qv = sums[0]; \n"); 00401 source.append(" for(uint j = lcl_id; j < size2; j += lcl_sz) \n"); 00402 source.append(" QR[i * strideQ + j] = QR[i * strideQ + j] - (2 * V[j] * sum_Qv); \n"); 00403 source.append(" } \n"); 00404 source.append("} \n"); 00405 } 00406 00407 template <typename StringType> 00408 void generate_svd_inverse_signs(StringType & source, std::string const & numeric_string) 00409 { 00410 source.append("__kernel void inverse_signs(__global "); source.append(numeric_string); source.append("* v, \n"); 00411 source.append(" __global "); source.append(numeric_string); source.append("* signs, \n"); 00412 source.append(" uint size, \n"); 00413 source.append(" uint stride \n"); 00414 source.append(" ) \n"); 00415 source.append("{ \n"); 00416 source.append(" uint glb_id_x = get_global_id(0); \n"); 00417 source.append(" uint glb_id_y = get_global_id(1); \n"); 00418 00419 source.append(" if((glb_id_x < size) && (glb_id_y < size)) \n"); 00420 source.append(" v[glb_id_x * stride + glb_id_y] *= signs[glb_id_x]; \n"); 00421 source.append("} \n"); 00422 00423 } 00424 00425 template <typename StringType> 00426 void generate_svd_transpose_inplace(StringType & source, std::string const & numeric_string) 00427 { 00428 00429 source.append("__kernel void transpose_inplace(__global "); source.append(numeric_string); source.append("* input, \n"); 00430 source.append(" unsigned int row_num, \n"); 00431 source.append(" unsigned int col_num) { \n"); 00432 source.append(" unsigned int size = row_num * col_num; \n"); 00433 source.append(" for(unsigned int i = get_global_id(0); i < size; i+= get_global_size(0)) { \n"); 00434 source.append(" unsigned int row = i / col_num; \n"); 00435 source.append(" unsigned int col = i - row*col_num; \n"); 00436 00437 source.append(" unsigned int new_pos = col * row_num + row; \n"); 00438 00439 //new_pos = (col < row) ? 0 : 1; 00440 //input[i] = new_pos; 00441 00442 source.append(" if(i < new_pos) { \n"); 00443 source.append(" "); source.append(numeric_string); source.append(" val = input[i]; \n"); 00444 source.append(" input[i] = input[new_pos]; \n"); 00445 source.append(" input[new_pos] = val; \n"); 00446 source.append(" } \n"); 00447 source.append(" } \n"); 00448 source.append("} \n"); 00449 00450 } 00451 00452 template <typename StringType> 00453 void generate_svd_update_qr_column(StringType & source, std::string const & numeric_string) 00454 { 00455 source.append("__kernel void update_qr_column(__global "); source.append(numeric_string); source.append("* A, \n"); 00456 source.append(" uint stride, \n"); 00457 source.append(" __global "); source.append(numeric_string); source.append("* buf, \n"); 00458 source.append(" int m, \n"); 00459 source.append(" int n, \n"); 00460 source.append(" int last_n) \n"); 00461 source.append("{ \n"); 00462 source.append(" uint glb_id = get_global_id(0); \n"); 00463 source.append(" uint glb_sz = get_global_size(0); \n"); 00464 00465 source.append(" for (int i = glb_id; i < last_n; i += glb_sz) \n"); 00466 source.append(" { \n"); 00467 source.append(" "); source.append(numeric_string); source.append(" a_ik = A[m * stride + i], a_ik_1, a_ik_2; \n"); 00468 00469 source.append(" a_ik_1 = A[(m + 1) * stride + i]; \n"); 00470 00471 source.append(" for(int k = m; k < n; k++) \n"); 00472 source.append(" { \n"); 00473 source.append(" bool notlast = (k != n - 1); \n"); 00474 00475 source.append(" "); source.append(numeric_string); source.append(" p = buf[5 * k] * a_ik + buf[5 * k + 1] * a_ik_1; \n"); 00476 00477 source.append(" if (notlast) \n"); 00478 source.append(" { \n"); 00479 source.append(" a_ik_2 = A[(k + 2) * stride + i]; \n"); 00480 source.append(" p = p + buf[5 * k + 2] * a_ik_2; \n"); 00481 source.append(" a_ik_2 = a_ik_2 - p * buf[5 * k + 4]; \n"); 00482 source.append(" } \n"); 00483 00484 source.append(" A[k * stride + i] = a_ik - p; \n"); 00485 source.append(" a_ik_1 = a_ik_1 - p * buf[5 * k + 3]; \n"); 00486 00487 source.append(" a_ik = a_ik_1; \n"); 00488 source.append(" a_ik_1 = a_ik_2; \n"); 00489 source.append(" } \n"); 00490 00491 source.append(" A[n * stride + i] = a_ik; \n"); 00492 source.append(" } \n"); 00493 00494 source.append("} \n"); 00495 } 00496 00497 00498 00499 00500 // main kernel class 00502 template <class NumericT> 00503 struct svd 00504 { 00505 static std::string program_name() 00506 { 00507 return viennacl::ocl::type_to_string<NumericT>::apply() + "_svd"; 00508 } 00509 00510 static void init(viennacl::ocl::context & ctx) 00511 { 00512 viennacl::ocl::DOUBLE_PRECISION_CHECKER<NumericT>::apply(ctx); 00513 std::string numeric_string = viennacl::ocl::type_to_string<NumericT>::apply(); 00514 00515 static std::map<cl_context, bool> init_done; 00516 if (!init_done[ctx.handle().get()]) 00517 { 00518 std::string source; 00519 source.reserve(1024); 00520 00521 viennacl::ocl::append_double_precision_pragma<NumericT>(ctx, source); 00522 00523 // only generate for floating points (forces error for integers) 00524 if (numeric_string == "float" || numeric_string == "double") 00525 { 00526 //helper function used by multiple kernels: 00527 generate_svd_col_reduce_lcl_array(source, numeric_string); 00528 00529 //kernels: 00530 generate_svd_bidiag_pack(source, numeric_string); 00531 generate_svd_copy_col(source, numeric_string); 00532 generate_svd_copy_row(source, numeric_string); 00533 generate_svd_final_iter_update(source, numeric_string); 00534 generate_svd_givens_next(source, numeric_string); 00535 generate_svd_givens_prev(source, numeric_string); 00536 generate_svd_house_update_A_left(source, numeric_string); 00537 generate_svd_house_update_A_right(source, numeric_string); 00538 generate_svd_house_update_QL(source, numeric_string); 00539 generate_svd_house_update_QR(source, numeric_string); 00540 generate_svd_inverse_signs(source, numeric_string); 00541 generate_svd_transpose_inplace(source, numeric_string); 00542 generate_svd_update_qr_column(source, numeric_string); 00543 } 00544 00545 std::string prog_name = program_name(); 00546 #ifdef VIENNACL_BUILD_INFO 00547 std::cout << "Creating program " << prog_name << std::endl; 00548 #endif 00549 ctx.add_program(source, prog_name); 00550 init_done[ctx.handle().get()] = true; 00551 } //if 00552 } //init 00553 }; 00554 00555 } // namespace kernels 00556 } // namespace opencl 00557 } // namespace linalg 00558 } // namespace viennacl 00559 #endif 00560