ViennaCL - The Vienna Computing Library  1.5.2
viennacl/linalg/opencl/kernels/coordinate_matrix.hpp
Go to the documentation of this file.
00001 #ifndef VIENNACL_LINALG_OPENCL_KERNELS_COORDINATE_MATRIX_HPP
00002 #define VIENNACL_LINALG_OPENCL_KERNELS_COORDINATE_MATRIX_HPP
00003 
00004 #include "viennacl/tools/tools.hpp"
00005 #include "viennacl/ocl/kernel.hpp"
00006 #include "viennacl/ocl/platform.hpp"
00007 #include "viennacl/ocl/utils.hpp"
00008 
00009 #include "viennacl/linalg/opencl/common.hpp"
00010 
00013 namespace viennacl
00014 {
00015   namespace linalg
00016   {
00017     namespace opencl
00018     {
00019       namespace kernels
00020       {
00021 
00023 
00024         template <typename StringType>
00025         void generate_coordinate_matrix_vec_mul(StringType & source, std::string const & numeric_string)
00026         {
00027           source.append("__kernel void vec_mul( \n");
00028           source.append("  __global const uint2 * coords,  \n");//(row_index, column_index)
00029           source.append("  __global const "); source.append(numeric_string); source.append(" * elements, \n");
00030           source.append("  __global const uint  * group_boundaries, \n");
00031           source.append("  __global const "); source.append(numeric_string); source.append(" * x, \n");
00032           source.append("  uint4 layout_x, \n");
00033           source.append("  __global "); source.append(numeric_string); source.append(" * result, \n");
00034           source.append("  uint4 layout_result, \n");
00035           source.append("  __local unsigned int * shared_rows, \n");
00036           source.append("  __local "); source.append(numeric_string); source.append(" * inter_results) \n");
00037           source.append("{ \n");
00038           source.append("  uint2 tmp; \n");
00039           source.append("  "); source.append(numeric_string); source.append(" val; \n");
00040           source.append("  uint group_start = group_boundaries[get_group_id(0)]; \n");
00041           source.append("  uint group_end   = group_boundaries[get_group_id(0) + 1]; \n");
00042           source.append("  uint k_end = (group_end > group_start) ? 1 + (group_end - group_start - 1) / get_local_size(0) : 0; \n");   // -1 in order to have correct behavior if group_end - group_start == j * get_local_size(0)
00043 
00044           source.append("  uint local_index = 0; \n");
00045 
00046           source.append("  for (uint k = 0; k < k_end; ++k) { \n");
00047           source.append("    local_index = group_start + k * get_local_size(0) + get_local_id(0); \n");
00048 
00049           source.append("    tmp = (local_index < group_end) ? coords[local_index] : (uint2) 0; \n");
00050           source.append("    val = (local_index < group_end) ? elements[local_index] * x[tmp.y * layout_x.y + layout_x.x] : 0; \n");
00051 
00052           //check for carry from previous loop run:
00053           source.append("    if (get_local_id(0) == 0 && k > 0) { \n");
00054           source.append("      if (tmp.x == shared_rows[get_local_size(0)-1]) \n");
00055           source.append("        val += inter_results[get_local_size(0)-1]; \n");
00056           source.append("      else \n");
00057           source.append("        result[shared_rows[get_local_size(0)-1] * layout_result.y + layout_result.x] = inter_results[get_local_size(0)-1]; \n");
00058           source.append("    } \n");
00059 
00060           //segmented parallel reduction begin
00061           source.append("    barrier(CLK_LOCAL_MEM_FENCE); \n");
00062           source.append("    shared_rows[get_local_id(0)] = tmp.x; \n");
00063           source.append("    inter_results[get_local_id(0)] = val; \n");
00064           source.append("    "); source.append(numeric_string); source.append(" left = 0; \n");
00065           source.append("    barrier(CLK_LOCAL_MEM_FENCE); \n");
00066 
00067           source.append("    for (unsigned int stride = 1; stride < get_local_size(0); stride *= 2) { \n");
00068           source.append("      left = (get_local_id(0) >= stride && tmp.x == shared_rows[get_local_id(0) - stride]) ? inter_results[get_local_id(0) - stride] : 0; \n");
00069           source.append("      barrier(CLK_LOCAL_MEM_FENCE); \n");
00070           source.append("      inter_results[get_local_id(0)] += left; \n");
00071           source.append("      barrier(CLK_LOCAL_MEM_FENCE); \n");
00072           source.append("    } \n");
00073           //segmented parallel reduction end
00074 
00075           source.append("    if (local_index < group_end && get_local_id(0) < get_local_size(0) - 1 && \n");
00076           source.append("      shared_rows[get_local_id(0)] != shared_rows[get_local_id(0) + 1]) { \n");
00077           source.append("      result[tmp.x * layout_result.y + layout_result.x] = inter_results[get_local_id(0)]; \n");
00078           source.append("    } \n");
00079 
00080           source.append("    barrier(CLK_LOCAL_MEM_FENCE); \n");
00081           source.append("  }  \n"); //for k
00082 
00083           source.append("  if (local_index + 1 == group_end) \n");  //write results of last active entry (this may not necessarily be the case already)
00084           source.append("    result[tmp.x * layout_result.y + layout_result.x] = inter_results[get_local_id(0)]; \n");
00085           source.append("} \n");
00086 
00087         }
00088 
00089         namespace detail
00090         {
00092           template <typename StringType>
00093           void generate_coordinate_matrix_dense_matrix_mul(StringType & source, std::string const & numeric_string,
00094                                                            bool B_transposed, bool B_row_major, bool C_row_major)
00095           {
00096             source.append("__kernel void ");
00097             source.append(viennacl::linalg::opencl::detail::sparse_dense_matmult_kernel_name(B_transposed, B_row_major, C_row_major));
00098             source.append("( \n");
00099             source.append("  __global const uint2 * coords,  \n");//(row_index, column_index)
00100             source.append("  __global const "); source.append(numeric_string); source.append(" * elements, \n");
00101             source.append("  __global const uint  * group_boundaries, \n");
00102             source.append("  __global const "); source.append(numeric_string); source.append(" * d_mat, \n");
00103             source.append("  unsigned int d_mat_row_start, \n");
00104             source.append("  unsigned int d_mat_col_start, \n");
00105             source.append("  unsigned int d_mat_row_inc, \n");
00106             source.append("  unsigned int d_mat_col_inc, \n");
00107             source.append("  unsigned int d_mat_row_size, \n");
00108             source.append("  unsigned int d_mat_col_size, \n");
00109             source.append("  unsigned int d_mat_internal_rows, \n");
00110             source.append("  unsigned int d_mat_internal_cols, \n");
00111             source.append("  __global "); source.append(numeric_string); source.append(" * result, \n");
00112             source.append("  unsigned int result_row_start, \n");
00113             source.append("  unsigned int result_col_start, \n");
00114             source.append("  unsigned int result_row_inc, \n");
00115             source.append("  unsigned int result_col_inc, \n");
00116             source.append("  unsigned int result_row_size, \n");
00117             source.append("  unsigned int result_col_size, \n");
00118             source.append("  unsigned int result_internal_rows, \n");
00119             source.append("  unsigned int result_internal_cols, \n");
00120             source.append("  __local unsigned int * shared_rows, \n");
00121             source.append("  __local "); source.append(numeric_string); source.append(" * inter_results) \n");
00122             source.append("{ \n");
00123             source.append("  uint2 tmp; \n");
00124             source.append("  "); source.append(numeric_string); source.append(" val; \n");
00125             source.append("  uint group_start = group_boundaries[get_group_id(0)]; \n");
00126             source.append("  uint group_end   = group_boundaries[get_group_id(0) + 1]; \n");
00127             source.append("  uint k_end = (group_end > group_start) ? 1 + (group_end - group_start - 1) / get_local_size(0) : 0; \n");   // -1 in order to have correct behavior if group_end - group_start == j * get_local_size(0)
00128 
00129             source.append("  uint local_index = 0; \n");
00130 
00131             source.append("  for (uint result_col = 0; result_col < result_col_size; ++result_col) { \n");
00132             source.append("   for (uint k = 0; k < k_end; ++k) { \n");
00133             source.append("    local_index = group_start + k * get_local_size(0) + get_local_id(0); \n");
00134 
00135             source.append("    tmp = (local_index < group_end) ? coords[local_index] : (uint2) 0; \n");
00136             if (B_transposed && B_row_major)
00137               source.append("    val = (local_index < group_end) ? elements[local_index] * d_mat[ (d_mat_row_start + result_col * d_mat_row_inc) * d_mat_internal_cols + d_mat_col_start +      tmp.y * d_mat_col_inc ] : 0; \n");
00138             if (B_transposed && !B_row_major)
00139               source.append("    val = (local_index < group_end) ? elements[local_index] * d_mat[ (d_mat_row_start + result_col * d_mat_row_inc)                       + (d_mat_col_start +      tmp.y * d_mat_col_inc) * d_mat_internal_rows ] : 0; \n");
00140             else if (!B_transposed && B_row_major)
00141               source.append("    val = (local_index < group_end) ? elements[local_index] * d_mat[ (d_mat_row_start +      tmp.y * d_mat_row_inc) * d_mat_internal_cols + d_mat_col_start + result_col * d_mat_col_inc ] : 0; \n");
00142             else
00143               source.append("    val = (local_index < group_end) ? elements[local_index] * d_mat[ (d_mat_row_start +      tmp.y * d_mat_row_inc)                       + (d_mat_col_start + result_col * d_mat_col_inc) * d_mat_internal_rows ] : 0; \n");
00144 
00145             //check for carry from previous loop run:
00146             source.append("    if (get_local_id(0) == 0 && k > 0) { \n");
00147             source.append("      if (tmp.x == shared_rows[get_local_size(0)-1]) \n");
00148             source.append("        val += inter_results[get_local_size(0)-1]; \n");
00149             source.append("      else \n");
00150             if (C_row_major)
00151               source.append("        result[(shared_rows[get_local_size(0)-1] * result_row_inc + result_row_start) * result_internal_cols + result_col_start + result_col * result_col_inc ] = inter_results[get_local_size(0)-1]; \n");
00152             else
00153               source.append("        result[(shared_rows[get_local_size(0)-1] * result_row_inc + result_row_start)                        + (result_col_start + result_col * result_col_inc) * result_internal_rows ] = inter_results[get_local_size(0)-1]; \n");
00154             source.append("    } \n");
00155 
00156             //segmented parallel reduction begin
00157             source.append("    barrier(CLK_LOCAL_MEM_FENCE); \n");
00158             source.append("    shared_rows[get_local_id(0)] = tmp.x; \n");
00159             source.append("    inter_results[get_local_id(0)] = val; \n");
00160             source.append("    "); source.append(numeric_string); source.append(" left = 0; \n");
00161             source.append("    barrier(CLK_LOCAL_MEM_FENCE); \n");
00162 
00163             source.append("    for (unsigned int stride = 1; stride < get_local_size(0); stride *= 2) { \n");
00164             source.append("      left = (get_local_id(0) >= stride && tmp.x == shared_rows[get_local_id(0) - stride]) ? inter_results[get_local_id(0) - stride] : 0; \n");
00165             source.append("      barrier(CLK_LOCAL_MEM_FENCE); \n");
00166             source.append("      inter_results[get_local_id(0)] += left; \n");
00167             source.append("      barrier(CLK_LOCAL_MEM_FENCE); \n");
00168             source.append("    } \n");
00169             //segmented parallel reduction end
00170 
00171             source.append("    if (local_index < group_end && get_local_id(0) < get_local_size(0) - 1 && \n");
00172             source.append("      shared_rows[get_local_id(0)] != shared_rows[get_local_id(0) + 1]) { \n");
00173             if (C_row_major)
00174               source.append("      result[(tmp.x * result_row_inc + result_row_start) * result_internal_cols + result_col_start + result_col * result_col_inc ] = inter_results[get_local_id(0)]; \n");
00175             else
00176               source.append("      result[(tmp.x * result_row_inc + result_row_start)                        + (result_col_start + result_col * result_col_inc) * result_internal_rows ] = inter_results[get_local_id(0)]; \n");
00177             source.append("    } \n");
00178 
00179             source.append("    barrier(CLK_LOCAL_MEM_FENCE); \n");
00180             source.append("   }  \n"); //for k
00181 
00182             source.append("   if (local_index + 1 == group_end) \n");  //write results of last active entry (this may not necessarily be the case already)
00183             if (C_row_major)
00184               source.append("    result[(tmp.x  * result_row_inc + result_row_start) * result_internal_cols + result_col_start + result_col * result_col_inc ] = inter_results[get_local_id(0)]; \n");
00185             else
00186               source.append("    result[(tmp.x  * result_row_inc + result_row_start)                        + (result_col_start + result_col * result_col_inc) * result_internal_rows ] = inter_results[get_local_id(0)]; \n");
00187             source.append("  } \n"); //for result_col
00188             source.append("} \n");
00189 
00190           }
00191         }
00192 
00193         template <typename StringType>
00194         void generate_coordinate_matrix_dense_matrix_multiplication(StringType & source, std::string const & numeric_string)
00195         {
00196           detail::generate_coordinate_matrix_dense_matrix_mul(source, numeric_string, false, false, false);
00197           detail::generate_coordinate_matrix_dense_matrix_mul(source, numeric_string, false, false,  true);
00198           detail::generate_coordinate_matrix_dense_matrix_mul(source, numeric_string, false,  true, false);
00199           detail::generate_coordinate_matrix_dense_matrix_mul(source, numeric_string, false,  true,  true);
00200 
00201           detail::generate_coordinate_matrix_dense_matrix_mul(source, numeric_string, true, false, false);
00202           detail::generate_coordinate_matrix_dense_matrix_mul(source, numeric_string, true, false,  true);
00203           detail::generate_coordinate_matrix_dense_matrix_mul(source, numeric_string, true,  true, false);
00204           detail::generate_coordinate_matrix_dense_matrix_mul(source, numeric_string, true,  true,  true);
00205         }
00206 
00207         template <typename StringType>
00208         void generate_coordinate_matrix_row_info_extractor(StringType & source, std::string const & numeric_string)
00209         {
00210           source.append("__kernel void row_info_extractor( \n");
00211           source.append("          __global const uint2 * coords,  \n");//(row_index, column_index)
00212           source.append("          __global const "); source.append(numeric_string); source.append(" * elements, \n");
00213           source.append("          __global const uint  * group_boundaries, \n");
00214           source.append("          __global "); source.append(numeric_string); source.append(" * result, \n");
00215           source.append("          unsigned int option, \n");
00216           source.append("          __local unsigned int * shared_rows, \n");
00217           source.append("          __local "); source.append(numeric_string); source.append(" * inter_results) \n");
00218           source.append("{ \n");
00219           source.append("  uint2 tmp; \n");
00220           source.append("  "); source.append(numeric_string); source.append(" val; \n");
00221           source.append("  uint last_index  = get_local_size(0) - 1; \n");
00222           source.append("  uint group_start = group_boundaries[get_group_id(0)]; \n");
00223           source.append("  uint group_end   = group_boundaries[get_group_id(0) + 1]; \n");
00224           source.append("  uint k_end = (group_end > group_start) ? 1 + (group_end - group_start - 1) / get_local_size(0) : ("); source.append(numeric_string); source.append(")0; \n");   // -1 in order to have correct behavior if group_end - group_start == j * get_local_size(0)
00225 
00226           source.append("  uint local_index = 0; \n");
00227 
00228           source.append("  for (uint k = 0; k < k_end; ++k) \n");
00229           source.append("  { \n");
00230           source.append("    local_index = group_start + k * get_local_size(0) + get_local_id(0); \n");
00231 
00232           source.append("    tmp = (local_index < group_end) ? coords[local_index] : (uint2) 0; \n");
00233           source.append("    val = (local_index < group_end && (option != 3 || tmp.x == tmp.y) ) ? elements[local_index] : 0; \n");
00234 
00235               //check for carry from previous loop run:
00236           source.append("    if (get_local_id(0) == 0 && k > 0) \n");
00237           source.append("    { \n");
00238           source.append("      if (tmp.x == shared_rows[last_index]) \n");
00239           source.append("      { \n");
00240           source.append("        switch (option) \n");
00241           source.append("        { \n");
00242           source.append("          case 0: \n"); //inf-norm
00243           source.append("          case 3: \n"); //diagonal entry
00244           source.append("            val = max(val, fabs(inter_results[last_index])); \n");
00245           source.append("            break; \n");
00246 
00247           source.append("          case 1: \n"); //1-norm
00248           source.append("            val = fabs(val) + inter_results[last_index]; \n");
00249           source.append("            break; \n");
00250 
00251           source.append("          case 2: \n"); //2-norm
00252           source.append("            val = sqrt(val * val + inter_results[last_index]); \n");
00253           source.append("            break; \n");
00254 
00255           source.append("          default: \n");
00256           source.append("            break; \n");
00257           source.append("        } \n");
00258           source.append("      } \n");
00259           source.append("      else \n");
00260           source.append("      { \n");
00261           source.append("        switch (option) \n");
00262           source.append("        { \n");
00263           source.append("          case 0: \n"); //inf-norm
00264           source.append("          case 1: \n"); //1-norm
00265           source.append("          case 3: \n"); //diagonal entry
00266           source.append("            result[shared_rows[last_index]] = inter_results[last_index]; \n");
00267           source.append("            break; \n");
00268 
00269           source.append("          case 2: \n"); //2-norm
00270           source.append("            result[shared_rows[last_index]] = sqrt(inter_results[last_index]); \n");
00271           source.append("          default: \n");
00272           source.append("            break; \n");
00273           source.append("        } \n");
00274           source.append("      } \n");
00275           source.append("    } \n");
00276 
00277               //segmented parallel reduction begin
00278           source.append("    barrier(CLK_LOCAL_MEM_FENCE); \n");
00279           source.append("    shared_rows[get_local_id(0)] = tmp.x; \n");
00280           source.append("    switch (option) \n");
00281           source.append("    { \n");
00282           source.append("      case 0: \n");
00283           source.append("      case 3: \n");
00284           source.append("        inter_results[get_local_id(0)] = val; \n");
00285           source.append("        break; \n");
00286           source.append("      case 1: \n");
00287           source.append("        inter_results[get_local_id(0)] = fabs(val); \n");
00288           source.append("        break; \n");
00289           source.append("      case 2: \n");
00290           source.append("        inter_results[get_local_id(0)] = val * val; \n");
00291           source.append("      default: \n");
00292           source.append("        break; \n");
00293           source.append("    } \n");
00294           source.append("    "); source.append(numeric_string); source.append(" left = 0; \n");
00295           source.append("    barrier(CLK_LOCAL_MEM_FENCE); \n");
00296 
00297           source.append("    for (unsigned int stride = 1; stride < get_local_size(0); stride *= 2) \n");
00298           source.append("    { \n");
00299           source.append("      left = (get_local_id(0) >= stride && tmp.x == shared_rows[get_local_id(0) - stride]) ? inter_results[get_local_id(0) - stride] : ("); source.append(numeric_string); source.append(")0; \n");
00300           source.append("      barrier(CLK_LOCAL_MEM_FENCE); \n");
00301           source.append("      switch (option) \n");
00302           source.append("      { \n");
00303           source.append("        case 0: \n"); //inf-norm
00304           source.append("        case 3: \n"); //diagonal entry
00305           source.append("          inter_results[get_local_id(0)] = max(inter_results[get_local_id(0)], left); \n");
00306           source.append("          break; \n");
00307 
00308           source.append("        case 1: \n"); //1-norm
00309           source.append("          inter_results[get_local_id(0)] += left; \n");
00310           source.append("          break; \n");
00311 
00312           source.append("        case 2: \n"); //2-norm
00313           source.append("          inter_results[get_local_id(0)] += left; \n");
00314           source.append("          break; \n");
00315 
00316           source.append("        default: \n");
00317           source.append("          break; \n");
00318           source.append("      } \n");
00319           source.append("      barrier(CLK_LOCAL_MEM_FENCE); \n");
00320           source.append("    } \n");
00321               //segmented parallel reduction end
00322 
00323           source.append("    if (get_local_id(0) != last_index && \n");
00324           source.append("        shared_rows[get_local_id(0)] != shared_rows[get_local_id(0) + 1] && \n");
00325           source.append("        inter_results[get_local_id(0)] != 0) \n");
00326           source.append("    { \n");
00327           source.append("      result[tmp.x] = (option == 2) ? sqrt(inter_results[get_local_id(0)]) : inter_results[get_local_id(0)]; \n");
00328           source.append("    } \n");
00329 
00330           source.append("    barrier(CLK_LOCAL_MEM_FENCE); \n");
00331           source.append("  } \n"); //for k
00332 
00333           source.append("  if (get_local_id(0) == last_index && inter_results[last_index] != 0) \n");
00334           source.append("    result[tmp.x] = (option == 2) ? sqrt(inter_results[last_index]) : inter_results[last_index]; \n");
00335           source.append("} \n");
00336         }
00337 
00339 
00340         // main kernel class
00342         template <typename NumericT>
00343         struct coordinate_matrix
00344         {
00345           static std::string program_name()
00346           {
00347             return viennacl::ocl::type_to_string<NumericT>::apply() + "_coordinate_matrix";
00348           }
00349 
00350           static void init(viennacl::ocl::context & ctx)
00351           {
00352             viennacl::ocl::DOUBLE_PRECISION_CHECKER<NumericT>::apply(ctx);
00353             std::string numeric_string = viennacl::ocl::type_to_string<NumericT>::apply();
00354 
00355             static std::map<cl_context, bool> init_done;
00356             if (!init_done[ctx.handle().get()])
00357             {
00358               std::string source;
00359               source.reserve(1024);
00360 
00361               viennacl::ocl::append_double_precision_pragma<NumericT>(ctx, source);
00362 
00363               generate_coordinate_matrix_vec_mul(source, numeric_string);
00364               generate_coordinate_matrix_dense_matrix_multiplication(source, numeric_string);
00365               generate_coordinate_matrix_row_info_extractor(source, numeric_string);
00366 
00367               std::string prog_name = program_name();
00368               #ifdef VIENNACL_BUILD_INFO
00369               std::cout << "Creating program " << prog_name << std::endl;
00370               #endif
00371               ctx.add_program(source, prog_name);
00372               init_done[ctx.handle().get()] = true;
00373             } //if
00374           } //init
00375         };
00376 
00377       }  // namespace kernels
00378     }  // namespace opencl
00379   }  // namespace linalg
00380 }  // namespace viennacl
00381 #endif
00382