// ViennaCL - The Vienna Computing Library
// Version 1.5.2
00001 #ifndef VIENNACL_LINALG_OPENCL_KERNELS_HYB_MATRIX_HPP 00002 #define VIENNACL_LINALG_OPENCL_KERNELS_HYB_MATRIX_HPP 00003 00004 #include "viennacl/tools/tools.hpp" 00005 #include "viennacl/ocl/kernel.hpp" 00006 #include "viennacl/ocl/platform.hpp" 00007 #include "viennacl/ocl/utils.hpp" 00008 00009 #include "viennacl/linalg/opencl/common.hpp" 00010 00013 namespace viennacl 00014 { 00015 namespace linalg 00016 { 00017 namespace opencl 00018 { 00019 namespace kernels 00020 { 00021 00023 00024 template <typename StringType> 00025 void generate_hyb_vec_mul(StringType & source, std::string const & numeric_string) 00026 { 00027 source.append("__kernel void vec_mul( \n"); 00028 source.append(" const __global int* ell_coords, \n"); 00029 source.append(" const __global "); source.append(numeric_string); source.append("* ell_elements, \n"); 00030 source.append(" const __global uint* csr_rows, \n"); 00031 source.append(" const __global uint* csr_cols, \n"); 00032 source.append(" const __global "); source.append(numeric_string); source.append("* csr_elements, \n"); 00033 source.append(" const __global "); source.append(numeric_string); source.append(" * x, \n"); 00034 source.append(" uint4 layout_x, \n"); 00035 source.append(" __global "); source.append(numeric_string); source.append(" * result, \n"); 00036 source.append(" uint4 layout_result, \n"); 00037 source.append(" unsigned int row_num, \n"); 00038 source.append(" unsigned int internal_row_num, \n"); 00039 source.append(" unsigned int items_per_row, \n"); 00040 source.append(" unsigned int aligned_items_per_row) \n"); 00041 source.append("{ \n"); 00042 source.append(" uint glb_id = get_global_id(0); \n"); 00043 source.append(" uint glb_sz = get_global_size(0); \n"); 00044 00045 source.append(" for(uint row_id = glb_id; row_id < row_num; row_id += glb_sz) { \n"); 00046 source.append(" "); source.append(numeric_string); source.append(" sum = 0; \n"); 00047 00048 source.append(" uint offset = row_id; \n"); 00049 
source.append(" for(uint item_id = 0; item_id < items_per_row; item_id++, offset += internal_row_num) { \n"); 00050 source.append(" "); source.append(numeric_string); source.append(" val = ell_elements[offset]; \n"); 00051 00052 source.append(" if(val != ("); source.append(numeric_string); source.append(")0) { \n"); 00053 source.append(" int col = ell_coords[offset]; \n"); 00054 source.append(" sum += (x[col * layout_x.y + layout_x.x] * val); \n"); 00055 source.append(" } \n"); 00056 00057 source.append(" } \n"); 00058 00059 source.append(" uint col_begin = csr_rows[row_id]; \n"); 00060 source.append(" uint col_end = csr_rows[row_id + 1]; \n"); 00061 00062 source.append(" for(uint item_id = col_begin; item_id < col_end; item_id++) { \n"); 00063 source.append(" sum += (x[csr_cols[item_id] * layout_x.y + layout_x.x] * csr_elements[item_id]); \n"); 00064 source.append(" } \n"); 00065 00066 source.append(" result[row_id * layout_result.y + layout_result.x] = sum; \n"); 00067 source.append(" } \n"); 00068 source.append("} \n"); 00069 } 00070 00071 namespace detail 00072 { 00073 template <typename StringType> 00074 void generate_hyb_matrix_dense_matrix_mul(StringType & source, std::string const & numeric_string, 00075 bool B_transposed, bool B_row_major, bool C_row_major) 00076 { 00077 source.append("__kernel void "); 00078 source.append(viennacl::linalg::opencl::detail::sparse_dense_matmult_kernel_name(B_transposed, B_row_major, C_row_major)); 00079 source.append("( \n"); 00080 source.append(" const __global int* ell_coords, \n"); 00081 source.append(" const __global "); source.append(numeric_string); source.append("* ell_elements, \n"); 00082 source.append(" const __global uint* csr_rows, \n"); 00083 source.append(" const __global uint* csr_cols, \n"); 00084 source.append(" const __global "); source.append(numeric_string); source.append("* csr_elements, \n"); 00085 source.append(" unsigned int row_num, \n"); 00086 source.append(" unsigned int internal_row_num, \n"); 
00087 source.append(" unsigned int items_per_row, \n"); 00088 source.append(" unsigned int aligned_items_per_row, \n"); 00089 source.append(" __global const "); source.append(numeric_string); source.append("* d_mat, \n"); 00090 source.append(" unsigned int d_mat_row_start, \n"); 00091 source.append(" unsigned int d_mat_col_start, \n"); 00092 source.append(" unsigned int d_mat_row_inc, \n"); 00093 source.append(" unsigned int d_mat_col_inc, \n"); 00094 source.append(" unsigned int d_mat_row_size, \n"); 00095 source.append(" unsigned int d_mat_col_size, \n"); 00096 source.append(" unsigned int d_mat_internal_rows, \n"); 00097 source.append(" unsigned int d_mat_internal_cols, \n"); 00098 source.append(" __global "); source.append(numeric_string); source.append(" * result, \n"); 00099 source.append(" unsigned int result_row_start, \n"); 00100 source.append(" unsigned int result_col_start, \n"); 00101 source.append(" unsigned int result_row_inc, \n"); 00102 source.append(" unsigned int result_col_inc, \n"); 00103 source.append(" unsigned int result_row_size, \n"); 00104 source.append(" unsigned int result_col_size, \n"); 00105 source.append(" unsigned int result_internal_rows, \n"); 00106 source.append(" unsigned int result_internal_cols) { \n"); 00107 00108 source.append(" uint glb_id = get_global_id(0); \n"); 00109 source.append(" uint glb_sz = get_global_size(0); \n"); 00110 00111 source.append(" for(uint result_col = 0; result_col < result_col_size; ++result_col) { \n"); 00112 source.append(" for(uint row_id = glb_id; row_id < row_num; row_id += glb_sz) { \n"); 00113 source.append(" "); source.append(numeric_string); source.append(" sum = 0; \n"); 00114 00115 source.append(" uint offset = row_id; \n"); 00116 source.append(" for(uint item_id = 0; item_id < items_per_row; item_id++, offset += internal_row_num) { \n"); 00117 source.append(" "); source.append(numeric_string); source.append(" val = ell_elements[offset]; \n"); 00118 00119 source.append(" if(val != ("); 
source.append(numeric_string); source.append(")0) { \n"); 00120 source.append(" int col = ell_coords[offset]; \n"); 00121 if (B_transposed && B_row_major) 00122 source.append(" sum += d_mat[ (d_mat_row_start + result_col * d_mat_row_inc) * d_mat_internal_cols + d_mat_col_start + col * d_mat_col_inc ] * val; \n"); 00123 else if (B_transposed && !B_row_major) 00124 source.append(" sum += d_mat[ (d_mat_row_start + result_col * d_mat_row_inc) + (d_mat_col_start + col * d_mat_col_inc) * d_mat_internal_rows ] * val; \n"); 00125 else if (!B_transposed && B_row_major) 00126 source.append(" sum += d_mat[ (d_mat_row_start + col * d_mat_row_inc) * d_mat_internal_cols + d_mat_col_start + result_col * d_mat_col_inc ] * val; \n"); 00127 else 00128 source.append(" sum += d_mat[ (d_mat_row_start + col * d_mat_row_inc) + (d_mat_col_start + result_col * d_mat_col_inc) * d_mat_internal_rows ] * val; \n"); 00129 source.append(" } \n"); 00130 00131 source.append(" } \n"); 00132 00133 source.append(" uint col_begin = csr_rows[row_id]; \n"); 00134 source.append(" uint col_end = csr_rows[row_id + 1]; \n"); 00135 00136 source.append(" for(uint item_id = col_begin; item_id < col_end; item_id++) { \n"); 00137 if (B_transposed && B_row_major) 00138 source.append(" sum += d_mat[ (d_mat_row_start + result_col * d_mat_row_inc) * d_mat_internal_cols + d_mat_col_start + csr_cols[item_id] * d_mat_col_inc ] * csr_elements[item_id]; \n"); 00139 else if (B_transposed && !B_row_major) 00140 source.append(" sum += d_mat[ (d_mat_row_start + result_col * d_mat_row_inc) + (d_mat_col_start + csr_cols[item_id] * d_mat_col_inc) * d_mat_internal_rows ] * csr_elements[item_id]; \n"); 00141 else if (!B_transposed && B_row_major) 00142 source.append(" sum += d_mat[ (d_mat_row_start + csr_cols[item_id] * d_mat_row_inc) * d_mat_internal_cols + d_mat_col_start + result_col * d_mat_col_inc ] * csr_elements[item_id]; \n"); 00143 else 00144 source.append(" sum += d_mat[ (d_mat_row_start + csr_cols[item_id] * 
d_mat_row_inc) + (d_mat_col_start + result_col * d_mat_col_inc) * d_mat_internal_rows ] * csr_elements[item_id]; \n"); 00145 source.append(" } \n"); 00146 00147 if (C_row_major) 00148 source.append(" result[ (result_row_start + row_id * result_row_inc) * result_internal_cols + result_col_start + result_col * result_col_inc ] = sum; \n"); 00149 else 00150 source.append(" result[ (result_row_start + row_id * result_row_inc) + (result_col_start + result_col * result_col_inc) * result_internal_rows ] = sum; \n"); 00151 source.append(" } \n"); 00152 source.append(" } \n"); 00153 source.append("} \n"); 00154 } 00155 } 00156 00157 template <typename StringType> 00158 void generate_hyb_matrix_dense_matrix_multiplication(StringType & source, std::string const & numeric_string) 00159 { 00160 detail::generate_hyb_matrix_dense_matrix_mul(source, numeric_string, false, false, false); 00161 detail::generate_hyb_matrix_dense_matrix_mul(source, numeric_string, false, false, true); 00162 detail::generate_hyb_matrix_dense_matrix_mul(source, numeric_string, false, true, false); 00163 detail::generate_hyb_matrix_dense_matrix_mul(source, numeric_string, false, true, true); 00164 00165 detail::generate_hyb_matrix_dense_matrix_mul(source, numeric_string, true, false, false); 00166 detail::generate_hyb_matrix_dense_matrix_mul(source, numeric_string, true, false, true); 00167 detail::generate_hyb_matrix_dense_matrix_mul(source, numeric_string, true, true, false); 00168 detail::generate_hyb_matrix_dense_matrix_mul(source, numeric_string, true, true, true); 00169 } 00170 00172 00173 // main kernel class 00175 template <typename NumericT> 00176 struct hyb_matrix 00177 { 00178 static std::string program_name() 00179 { 00180 return viennacl::ocl::type_to_string<NumericT>::apply() + "_hyb_matrix"; 00181 } 00182 00183 static void init(viennacl::ocl::context & ctx) 00184 { 00185 viennacl::ocl::DOUBLE_PRECISION_CHECKER<NumericT>::apply(ctx); 00186 std::string numeric_string = 
viennacl::ocl::type_to_string<NumericT>::apply(); 00187 00188 static std::map<cl_context, bool> init_done; 00189 if (!init_done[ctx.handle().get()]) 00190 { 00191 std::string source; 00192 source.reserve(1024); 00193 00194 viennacl::ocl::append_double_precision_pragma<NumericT>(ctx, source); 00195 00196 generate_hyb_vec_mul(source, numeric_string); 00197 generate_hyb_matrix_dense_matrix_multiplication(source, numeric_string); 00198 00199 std::string prog_name = program_name(); 00200 #ifdef VIENNACL_BUILD_INFO 00201 std::cout << "Creating program " << prog_name << std::endl; 00202 #endif 00203 ctx.add_program(source, prog_name); 00204 init_done[ctx.handle().get()] = true; 00205 } //if 00206 } //init 00207 }; 00208 00209 } // namespace kernels 00210 } // namespace opencl 00211 } // namespace linalg 00212 } // namespace viennacl 00213 #endif 00214