ViennaCL - The Vienna Computing Library  1.5.2
viennacl/linalg/opencl/kernels/vector_element.hpp
Go to the documentation of this file.
00001 #ifndef VIENNACL_LINALG_OPENCL_KERNELS_VECTOR_ELEMENT_HPP
00002 #define VIENNACL_LINALG_OPENCL_KERNELS_VECTOR_ELEMENT_HPP
00003 
00004 #include "viennacl/tools/tools.hpp"
00005 #include "viennacl/ocl/kernel.hpp"
00006 #include "viennacl/ocl/platform.hpp"
00007 #include "viennacl/ocl/utils.hpp"
00008 
00011 namespace viennacl
00012 {
00013   namespace linalg
00014   {
00015     namespace opencl
00016     {
00017       namespace kernels
00018       {
00019 
00021 
00022 
/** @brief Appends the OpenCL source of an element-wise unary vector kernel to 'source'.
*
* The emitted kernel is named '<funcname>_<op_name>' and consists of a loop executing
*   vec1[i*size1.y+size1.x] <op> funcname(vec2[i*size2.y+size2.x]);
* where i starts at get_global_id(0), is incremented by get_global_size(0), and stays below size1.z.
*
* @param source          String-like object providing append(); receives the kernel source
* @param numeric_string  Name of the scalar type of both buffers (e.g. "float")
* @param funcname        Function applied to each entry read from vec2; also the first part of the kernel name
* @param op              Operator written between the destination entry and the function call (e.g. "=")
* @param op_name         Second part of the kernel name (e.g. "assign")
*/
template <typename StringType>
void generate_vector_unary_element_ops(StringType & source, std::string const & numeric_string,
                                       std::string const & funcname, std::string const & op, std::string const & op_name)
{
  // Both buffer arguments share the same declaration prefix:
  std::string const buffer_decl = "    __global " + numeric_string + " * ";

  // kernel signature
  source.append("__kernel void " + funcname + "_" + op_name + "(\n");
  source.append(buffer_decl + "vec1, \n");
  source.append("    uint4 size1, \n");
  source.append(buffer_decl + "vec2, \n");
  source.append("    uint4 size2) { \n");

  // kernel body
  source.append("  for (unsigned int i = get_global_id(0); i < size1.z; i += get_global_size(0)) \n");
  source.append("    vec1[i*size1.y+size1.x] " + op + " " + funcname + "(vec2[i*size2.y+size2.x]); \n");
  source.append("} \n");
}
00037 
/** @brief Appends the assigning variant of the unary element-wise kernel for 'funcname' to 'source'.
*
* Forwards to the five-argument overload with operator "=" and kernel name suffix "assign".
*
* @param source          String-like object receiving the kernel source
* @param numeric_string  Name of the scalar type (e.g. "float")
* @param funcname        Function to be applied element-wise
*/
template <typename StringType>
void generate_vector_unary_element_ops(StringType & source, std::string const & numeric_string, std::string const & funcname)
{
  std::string const assignment_operator("=");
  std::string const kernel_suffix("assign");

  generate_vector_unary_element_ops(source, numeric_string, funcname, assignment_operator, kernel_suffix);

  // The in-place variants are not generated at present:
  //generate_vector_unary_element_ops(source, numeric_string, funcname, "+=", "plus");
  //generate_vector_unary_element_ops(source, numeric_string, funcname, "-=", "minus");
}
00045 
/** @brief Appends the OpenCL source of the kernel 'element_op' (element-wise binary vector operations) to 'source'.
*
* The kernel argument 'op_type' selects the operation inside the generated code:
*   0: vec1 = vec2 * vec3,   1: vec1 = vec2 / vec3,   2: vec1 = pow(vec2, vec3)
* Entries are addressed as vec[i*inc + start], with i running below size1.
* The pow() branch is only emitted if numeric_string is "float" or "double".
*
* @param source          String-like object providing append(); receives the kernel source
* @param numeric_string  Name of the scalar type of all three buffers (e.g. "float")
*/
template <typename StringType>
void generate_vector_binary_element_ops(StringType & source, std::string const & numeric_string)
{
  bool const emit_pow_branch = (numeric_string == "float") || (numeric_string == "double");

  // Every branch of the kernel uses the same loop header:
  char const * const loop_header = "    for (unsigned int i = get_global_id(0); i < size1; i += get_global_size(0)) \n";

  // kernel signature: result vec1, operands vec2 and vec3, operation selector op_type
  source.append("__kernel void element_op( \n"
                "    __global ");
  source.append(numeric_string);
  source.append(" * vec1, \n"
                "    unsigned int start1, \n"
                "    unsigned int inc1, \n"
                "    unsigned int size1, \n"
                "    __global const ");
  source.append(numeric_string);
  source.append(" * vec2, \n"
                "    unsigned int start2, \n"
                "    unsigned int inc2, \n"
                "    __global const ");
  source.append(numeric_string);
  source.append(" * vec3, \n"
                "   unsigned int start3, \n"
                "   unsigned int inc3, \n"
                "   unsigned int op_type) \n"   //0: product, 1: division, 2: power
                "{ \n");

  // power (only generated for "float" and "double")
  if (emit_pow_branch)
  {
    source.append("  if (op_type == 2) \n"
                  "  { \n");
    source.append(loop_header);
    source.append("      vec1[i*inc1+start1] = pow(vec2[i*inc2+start2], vec3[i*inc3+start3]); \n"
                  "  } else ");
  }

  // division
  source.append("  if (op_type == 1) \n"
                "  { \n");
  source.append(loop_header);
  source.append("      vec1[i*inc1+start1] = vec2[i*inc2+start2] / vec3[i*inc3+start3]; \n"
                "  } \n");

  // product
  source.append("  else if (op_type == 0)\n"
                "  { \n");
  source.append(loop_header);
  source.append("      vec1[i*inc1+start1] = vec2[i*inc2+start2] * vec3[i*inc3+start3]; \n"
                "  } \n"
                "} \n");
}
00086 
00088 
00089         // main kernel class
00091         template <class TYPE>
00092         struct vector_element
00093         {
00094           static std::string program_name()
00095           {
00096             return viennacl::ocl::type_to_string<TYPE>::apply() + "_vector_element";
00097           }
00098 
00099           static void init(viennacl::ocl::context & ctx)
00100           {
00101             viennacl::ocl::DOUBLE_PRECISION_CHECKER<TYPE>::apply(ctx);
00102             std::string numeric_string = viennacl::ocl::type_to_string<TYPE>::apply();
00103 
00104             static std::map<cl_context, bool> init_done;
00105             if (!init_done[ctx.handle().get()])
00106             {
00107               std::string source;
00108               source.reserve(8192);
00109 
00110               viennacl::ocl::append_double_precision_pragma<TYPE>(ctx, source);
00111 
00112               // unary operations
00113               if (numeric_string == "float" || numeric_string == "double")
00114               {
00115                 generate_vector_unary_element_ops(source, numeric_string, "acos");
00116                 generate_vector_unary_element_ops(source, numeric_string, "asin");
00117                 generate_vector_unary_element_ops(source, numeric_string, "atan");
00118                 generate_vector_unary_element_ops(source, numeric_string, "ceil");
00119                 generate_vector_unary_element_ops(source, numeric_string, "cos");
00120                 generate_vector_unary_element_ops(source, numeric_string, "cosh");
00121                 generate_vector_unary_element_ops(source, numeric_string, "exp");
00122                 generate_vector_unary_element_ops(source, numeric_string, "fabs");
00123                 generate_vector_unary_element_ops(source, numeric_string, "floor");
00124                 generate_vector_unary_element_ops(source, numeric_string, "log");
00125                 generate_vector_unary_element_ops(source, numeric_string, "log10");
00126                 generate_vector_unary_element_ops(source, numeric_string, "sin");
00127                 generate_vector_unary_element_ops(source, numeric_string, "sinh");
00128                 generate_vector_unary_element_ops(source, numeric_string, "sqrt");
00129                 generate_vector_unary_element_ops(source, numeric_string, "tan");
00130                 generate_vector_unary_element_ops(source, numeric_string, "tanh");
00131               }
00132               else
00133               {
00134                 generate_vector_unary_element_ops(source, numeric_string, "abs");
00135               }
00136 
00137               // binary operations
00138               generate_vector_binary_element_ops(source, numeric_string);
00139 
00140               std::string prog_name = program_name();
00141               #ifdef VIENNACL_BUILD_INFO
00142               std::cout << "Creating program " << prog_name << std::endl;
00143               #endif
00144               ctx.add_program(source, prog_name);
00145               init_done[ctx.handle().get()] = true;
00146             } //if
00147           } //init
00148         };
00149 
00150       }  // namespace kernels
00151     }  // namespace opencl
00152   }  // namespace linalg
00153 }  // namespace viennacl
00154 #endif
00155