ViennaCL - The Vienna Computing Library  1.5.2
viennacl/linalg/opencl/kernels/matrix_solve.hpp
Go to the documentation of this file.
00001 #ifndef VIENNACL_LINALG_OPENCL_KERNELS_MATRIX_SOLVE_HPP
00002 #define VIENNACL_LINALG_OPENCL_KERNELS_MATRIX_SOLVE_HPP
00003 
00004 #include "viennacl/tools/tools.hpp"
00005 #include "viennacl/ocl/kernel.hpp"
00006 #include "viennacl/ocl/platform.hpp"
00007 #include "viennacl/ocl/utils.hpp"
00008 
00009 #include "viennacl/linalg/opencl/kernels/matrix.hpp"
00010 
00013 namespace viennacl
00014 {
00015   namespace linalg
00016   {
00017     namespace opencl
00018     {
00019       namespace kernels
00020       {
00021 
00022         template <typename StringType>
00023         void generate_matrix_solve_blas3(StringType & source, std::string const & numeric_string,
00024                                          bool row_major_A, bool row_major_B,
00025                                          bool transpose_A, bool transpose_B,
00026                                          bool upper_solve, bool unit_diagonal)
00027         {
00028           //start OpenCL code:
00029           source.append("__kernel void ");
00030           if (transpose_A)
00031             source.append("trans_");
00032           if (unit_diagonal)
00033             source.append("unit_");
00034           if (upper_solve)
00035             source.append("upper_");
00036           else
00037             source.append("lower_");
00038           if (transpose_B)
00039             source.append("trans_");
00040           source.append("solve");
00041 
00042           source.append("( \n");
00043           source.append("          __global const "); source.append(numeric_string); source.append(" * A, \n");
00044           source.append("          unsigned int A_start1, unsigned int A_start2, \n");
00045           source.append("          unsigned int A_inc1,   unsigned int A_inc2, \n");
00046           source.append("          unsigned int A_size1,  unsigned int A_size2, \n");
00047           source.append("          unsigned int A_internal_size1, unsigned int A_internal_size2, \n");
00048           source.append("          __global "); source.append(numeric_string); source.append(" * B, \n");
00049           source.append("          unsigned int B_start1, unsigned int B_start2, \n");
00050           source.append("          unsigned int B_inc1,   unsigned int B_inc2, \n");
00051           source.append("          unsigned int B_size1,  unsigned int B_size2, \n");
00052           source.append("          unsigned int B_internal_size1, unsigned int B_internal_size2) { \n");
00053           source.append("  "); source.append(numeric_string); source.append(" temp;  \n");
00054           if (upper_solve)
00055           {
00056             //Note: A is square, thus A_rows == A_cols and no dispatch for transposedness needed
00057             source.append("  for (unsigned int row_cnt = 0; row_cnt < A_size1; ++row_cnt)  \n");
00058             source.append("  {  \n");
00059             source.append("    unsigned int row = A_size1 - 1 - row_cnt; \n");
00060           }
00061           else //lower triangular solve
00062           {
00063             source.append("  for (unsigned int row = 0; row < A_size1; ++row) \n");
00064             source.append("  { \n");
00065           }
00066 
00067           if (!unit_diagonal)
00068           {
00069             source.append("    barrier(CLK_GLOBAL_MEM_FENCE); \n");
00070             source.append("    if (get_local_id(0) == 0)  \n");
00071             //Note: A is square, thus A_internal_rows == A_internal_cols and no dispatch for transposedness needed
00072             if (row_major_B && transpose_B)
00073               source.append("      B[(get_group_id(0) * B_inc1 + B_start1) * B_internal_size2 + (row * B_inc2 + B_start2)] /= ");
00074             else if (row_major_B && !transpose_B)
00075               source.append("      B[(row * B_inc1 + B_start1) * B_internal_size2 + (get_group_id(0) * B_inc2 + B_start2)] /= ");
00076             else if (!row_major_B && transpose_B)
00077               source.append("      B[(get_group_id(0) * B_inc1 + B_start1) + (row * B_inc2 + B_start2) * B_internal_size1] /= ");
00078             else if (!row_major_B && !transpose_B)
00079               source.append("      B[(row * B_inc1 + B_start1) + (get_group_id(0) * B_inc2 + B_start2) * B_internal_size1] /= ");
00080 
00081             if (row_major_A)
00082               source.append("A[(row * A_inc1 + A_start1) * A_internal_size2 + (row * A_inc2 + A_start2)]; \n");
00083             else
00084               source.append("A[(row * A_inc1 + A_start1) + (row * A_inc2 + A_start2)*A_internal_size1]; \n");
00085           }
00086 
00087           source.append("    barrier(CLK_GLOBAL_MEM_FENCE); \n");
00088 
00089           if (row_major_B && transpose_B)
00090             source.append("    temp = B[(get_group_id(0) * B_inc1 + B_start1) * B_internal_size2 + (row * B_inc2 + B_start2)]; \n");
00091           else if (row_major_B && !transpose_B)
00092             source.append("    temp = B[(row * B_inc1 + B_start1) * B_internal_size2 + (get_group_id(0) * B_inc2 + B_start2)]; \n");
00093           else if (!row_major_B && transpose_B)
00094             source.append("    temp = B[(get_group_id(0) * B_inc1 + B_start1) + (row * B_inc2 + B_start2) * B_internal_size1]; \n");
00095           else if (!row_major_B && !transpose_B)
00096             source.append("    temp = B[(row * B_inc1 + B_start1) + (get_group_id(0) * B_inc2 + B_start2) * B_internal_size1]; \n");
00097 
00098           source.append("    //eliminate column of op(A) with index 'row' in parallel: \n");
00099           if (upper_solve)
00100             source.append("    for  (unsigned int elim = get_local_id(0); elim < row; elim += get_local_size(0)) \n");
00101           else
00102             source.append("    for  (unsigned int elim = row + get_local_id(0) + 1; elim < A_size1; elim += get_local_size(0)) \n");
00103 
00104           if (row_major_B && transpose_B)
00105             source.append("      B[(get_group_id(0) * B_inc1 + B_start1) * B_internal_size2 + (elim * B_inc2 + B_start2)] -= temp * ");
00106           else if (row_major_B && !transpose_B)
00107             source.append("      B[(elim * B_inc1 + B_start1) * B_internal_size2 + (get_group_id(0) * B_inc2 + B_start2)] -= temp * ");
00108           else if (!row_major_B && transpose_B)
00109             source.append("      B[(get_group_id(0) * B_inc1 + B_start1) + (elim * B_inc2 + B_start2) * B_internal_size1] -= temp * ");
00110           else if (!row_major_B && !transpose_B)
00111             source.append("      B[(elim * B_inc1 + B_start1) + (get_group_id(0) * B_inc2 + B_start2) * B_internal_size1] -= temp * ");
00112 
00113           if (row_major_A && transpose_A)
00114             source.append("A[(row * A_inc1 + A_start1) * A_internal_size2 + (elim * A_inc2 + A_start2)]; \n");
00115           else if (row_major_A && !transpose_A)
00116             source.append("A[(elim * A_inc1 + A_start1) * A_internal_size2 + (row * A_inc2 + A_start2)]; \n");
00117           else if (!row_major_A && transpose_A)
00118             source.append("A[(row * A_inc1 + A_start1) + (elim * A_inc2 + A_start2) * A_internal_size1]; \n");
00119           else if (!row_major_A && !transpose_A)
00120             source.append("A[(elim * A_inc1 + A_start1) + (row * A_inc2 + A_start2) * A_internal_size1]; \n");
00121 
00122           source.append("   } \n");
00123           source.append("} \n");
00124         }
00125 
00126 
00127         // main kernel class
00133         template <class NumericT, typename F1, typename F2>
00134         struct matrix_solve
00135         {
00136           static std::string program_name()
00137           {
00138             return viennacl::ocl::type_to_string<NumericT>::apply() + "_matrix_solve_" + detail::type_to_string(F1()) + detail::type_to_string(F2());
00139           }
00140 
00141           static void init(viennacl::ocl::context & ctx)
00142           {
00143             viennacl::ocl::DOUBLE_PRECISION_CHECKER<NumericT>::apply(ctx);
00144             std::string numeric_string = viennacl::ocl::type_to_string<NumericT>::apply();
00145             bool matrix_row_major = viennacl::is_row_major<F1>::value;
00146             bool rhs_row_major    = viennacl::is_row_major<F2>::value;
00147 
00148 
00149             static std::map<cl_context, bool> init_done;
00150             if (!init_done[ctx.handle().get()])
00151             {
00152               std::string source;
00153               source.reserve(8192);
00154 
00155               viennacl::ocl::append_double_precision_pragma<NumericT>(ctx, source);
00156 
00157               // only generate for floating points (forces error for integers)
00158               if (numeric_string == "float" || numeric_string == "double")
00159               {
00160                 generate_matrix_solve_blas3(source, numeric_string, matrix_row_major, rhs_row_major,
00161                                             false, false, false, false);
00162                 generate_matrix_solve_blas3(source, numeric_string, matrix_row_major, rhs_row_major,
00163                                             false, false, false, true);
00164                 generate_matrix_solve_blas3(source, numeric_string, matrix_row_major, rhs_row_major,
00165                                             false, false, true, false);
00166                 generate_matrix_solve_blas3(source, numeric_string, matrix_row_major, rhs_row_major,
00167                                             false, false, true, true);
00168 
00169                 generate_matrix_solve_blas3(source, numeric_string, matrix_row_major, rhs_row_major,
00170                                             false, true, false, false);
00171                 generate_matrix_solve_blas3(source, numeric_string, matrix_row_major, rhs_row_major,
00172                                             false, true, false, true);
00173                 generate_matrix_solve_blas3(source, numeric_string, matrix_row_major, rhs_row_major,
00174                                             false, true, true, false);
00175                 generate_matrix_solve_blas3(source, numeric_string, matrix_row_major, rhs_row_major,
00176                                             false, true, true, true);
00177 
00178                 generate_matrix_solve_blas3(source, numeric_string, matrix_row_major, rhs_row_major,
00179                                             true, false, false, false);
00180                 generate_matrix_solve_blas3(source, numeric_string, matrix_row_major, rhs_row_major,
00181                                             true, false, false, true);
00182                 generate_matrix_solve_blas3(source, numeric_string, matrix_row_major, rhs_row_major,
00183                                             true, false, true, false);
00184                 generate_matrix_solve_blas3(source, numeric_string, matrix_row_major, rhs_row_major,
00185                                             true, false, true, true);
00186 
00187                 generate_matrix_solve_blas3(source, numeric_string, matrix_row_major, rhs_row_major,
00188                                             true, true, false, false);
00189                 generate_matrix_solve_blas3(source, numeric_string, matrix_row_major, rhs_row_major,
00190                                             true, true, false, true);
00191                 generate_matrix_solve_blas3(source, numeric_string, matrix_row_major, rhs_row_major,
00192                                             true, true, true, false);
00193                 generate_matrix_solve_blas3(source, numeric_string, matrix_row_major, rhs_row_major,
00194                                             true, true, true, true);
00195               }
00196 
00197               std::string prog_name = program_name();
00198               #ifdef VIENNACL_BUILD_INFO
00199               std::cout << "Creating program " << prog_name << std::endl;
00200               #endif
00201               ctx.add_program(source, prog_name);
00202               init_done[ctx.handle().get()] = true;
00203             } //if
00204           } //init
00205         };
00206 
00207       }  // namespace kernels
00208     }  // namespace opencl
00209   }  // namespace linalg
00210 }  // namespace viennacl
00211 #endif
00212