ViennaCL - The Vienna Computing Library  1.5.2
viennacl/linalg/opencl/direct_solve.hpp
Go to the documentation of this file.
00001 #ifndef VIENNACL_LINALG_OPENCL_DIRECT_SOLVE_HPP
00002 #define VIENNACL_LINALG_OPENCL_DIRECT_SOLVE_HPP
00003 
00004 /* =========================================================================
00005    Copyright (c) 2010-2014, Institute for Microelectronics,
00006                             Institute for Analysis and Scientific Computing,
00007                             TU Wien.
00008    Portions of this software are copyright by UChicago Argonne, LLC.
00009 
00010                             -----------------
00011                   ViennaCL - The Vienna Computing Library
00012                             -----------------
00013 
00014    Project Head:    Karl Rupp                   rupp@iue.tuwien.ac.at
00015 
00016    (A list of authors and contributors can be found in the PDF manual)
00017 
00018    License:         MIT (X11), see file LICENSE in the base directory
00019 ============================================================================= */
00020 
00025 #include "viennacl/vector.hpp"
00026 #include "viennacl/matrix.hpp"
00027 #include "viennacl/ocl/kernel.hpp"
00028 #include "viennacl/ocl/device.hpp"
00029 #include "viennacl/ocl/handle.hpp"
00030 #include "viennacl/linalg/opencl/kernels/matrix_solve.hpp"
00031 
00032 namespace viennacl
00033 {
00034   namespace linalg
00035   {
00036     namespace opencl
00037     {
00038       namespace detail
00039       {
00040         inline cl_uint get_option_for_solver_tag(viennacl::linalg::upper_tag)      { return 0; }
00041         inline cl_uint get_option_for_solver_tag(viennacl::linalg::unit_upper_tag) { return (1 << 0); }
00042         inline cl_uint get_option_for_solver_tag(viennacl::linalg::lower_tag)      { return (1 << 2); }
00043         inline cl_uint get_option_for_solver_tag(viennacl::linalg::unit_lower_tag) { return (1 << 2) | (1 << 0); }
00044 
00045         template <typename M1, typename M2, typename KernelType>
00046         void inplace_solve_impl(M1 const & A, M2 & B, KernelType & k)
00047         {
00048           viennacl::ocl::enqueue(k(viennacl::traits::opencl_handle(A),
00049                                    cl_uint(viennacl::traits::start1(A)),         cl_uint(viennacl::traits::start2(A)),
00050                                    cl_uint(viennacl::traits::stride1(A)),        cl_uint(viennacl::traits::stride2(A)),
00051                                    cl_uint(viennacl::traits::size1(A)),          cl_uint(viennacl::traits::size2(A)),
00052                                    cl_uint(viennacl::traits::internal_size1(A)), cl_uint(viennacl::traits::internal_size2(A)),
00053                                    viennacl::traits::opencl_handle(B),
00054                                    cl_uint(viennacl::traits::start1(B)),         cl_uint(viennacl::traits::start2(B)),
00055                                    cl_uint(viennacl::traits::stride1(B)),        cl_uint(viennacl::traits::stride2(B)),
00056                                    cl_uint(viennacl::traits::size1(B)),          cl_uint(viennacl::traits::size2(B)),
00057                                    cl_uint(viennacl::traits::internal_size1(B)), cl_uint(viennacl::traits::internal_size2(B))
00058                                   )
00059                                 );
00060         }
00061       }
00062 
00063 
00064       //
00065       // Note: By convention, all size checks are performed in the calling frontend. No need to double-check here.
00066       //
00067 
00069 
00074       template <typename NumericT, typename F1, typename F2, typename SOLVERTAG>
00075       void inplace_solve(const matrix_base<NumericT, F1> & A, matrix_base<NumericT, F2> & B, SOLVERTAG)
00076       {
00077         viennacl::ocl::context & ctx = const_cast<viennacl::ocl::context &>(viennacl::traits::opencl_handle(A).context());
00078 
00079         typedef viennacl::linalg::opencl::kernels::matrix_solve<NumericT, F1, F2>    KernelClass;
00080         KernelClass::init(ctx);
00081 
00082         std::stringstream ss;
00083         ss << SOLVERTAG::name() << "_solve";
00084         viennacl::ocl::kernel & k = ctx.get_kernel(KernelClass::program_name(), ss.str());
00085 
00086         k.global_work_size(0, B.size2() * k.local_work_size());
00087         detail::inplace_solve_impl(A, B, k);
00088       }
00089 
00095       template <typename NumericT, typename F1, typename F2, typename SOLVERTAG>
00096       void inplace_solve(const matrix_base<NumericT, F1> & A,
00097                          matrix_expression< const matrix_base<NumericT, F2>, const matrix_base<NumericT, F2>, op_trans> proxy_B,
00098                          SOLVERTAG)
00099       {
00100         viennacl::ocl::context & ctx = const_cast<viennacl::ocl::context &>(viennacl::traits::opencl_handle(A).context());
00101 
00102         typedef viennacl::linalg::opencl::kernels::matrix_solve<NumericT, F1, F2>    KernelClass;
00103         KernelClass::init(ctx);
00104 
00105         std::stringstream ss;
00106         ss << SOLVERTAG::name() << "_trans_solve";
00107         viennacl::ocl::kernel & k = ctx.get_kernel(KernelClass::program_name(), ss.str());
00108 
00109         k.global_work_size(0, proxy_B.lhs().size1() * k.local_work_size());
00110         detail::inplace_solve_impl(A, proxy_B.lhs(), k);
00111       }
00112 
00113       //upper triangular solver for transposed lower triangular matrices
00119       template <typename NumericT, typename F1, typename F2, typename SOLVERTAG>
00120       void inplace_solve(const matrix_expression< const matrix_base<NumericT, F1>, const matrix_base<NumericT, F1>, op_trans> & proxy_A,
00121                          matrix_base<NumericT, F2> & B,
00122                          SOLVERTAG)
00123       {
00124         viennacl::ocl::context & ctx = const_cast<viennacl::ocl::context &>(viennacl::traits::opencl_handle(B).context());
00125 
00126         typedef viennacl::linalg::opencl::kernels::matrix_solve<NumericT, F1, F2>    KernelClass;
00127         KernelClass::init(ctx);
00128 
00129         std::stringstream ss;
00130         ss << "trans_" << SOLVERTAG::name() << "_solve";
00131         viennacl::ocl::kernel & k = ctx.get_kernel(KernelClass::program_name(), ss.str());
00132 
00133         k.global_work_size(0, B.size2() * k.local_work_size());
00134         detail::inplace_solve_impl(proxy_A.lhs(), B, k);
00135       }
00136 
00142       template <typename NumericT, typename F1, typename F2, typename SOLVERTAG>
00143       void inplace_solve(const matrix_expression< const matrix_base<NumericT, F1>, const matrix_base<NumericT, F1>, op_trans> & proxy_A,
00144                                matrix_expression< const matrix_base<NumericT, F2>, const matrix_base<NumericT, F2>, op_trans>   proxy_B,
00145                          SOLVERTAG)
00146       {
00147         viennacl::ocl::context & ctx = const_cast<viennacl::ocl::context &>(viennacl::traits::opencl_handle(proxy_A.lhs()).context());
00148 
00149         typedef viennacl::linalg::opencl::kernels::matrix_solve<NumericT, F1, F2>    KernelClass;
00150         KernelClass::init(ctx);
00151 
00152         std::stringstream ss;
00153         ss << "trans_" << SOLVERTAG::name() << "_trans_solve";
00154         viennacl::ocl::kernel & k = ctx.get_kernel(KernelClass::program_name(), ss.str());
00155 
00156         k.global_work_size(0, proxy_B.lhs().size1() * k.local_work_size());
00157         detail::inplace_solve_impl(proxy_A.lhs(), proxy_B.lhs(), k);
00158       }
00159 
00160 
00161 
00162       //
00163       //  Solve on vector
00164       //
00165 
00166       template <typename NumericT, typename F, typename SOLVERTAG>
00167       void inplace_solve(const matrix_base<NumericT, F> & mat,
00168                                vector_base<NumericT> & vec,
00169                          SOLVERTAG)
00170       {
00171         viennacl::ocl::context & ctx = const_cast<viennacl::ocl::context &>(viennacl::traits::opencl_handle(mat).context());
00172 
00173         typedef viennacl::linalg::opencl::kernels::matrix<NumericT, F>  KernelClass;
00174         KernelClass::init(ctx);
00175 
00176         cl_uint options = detail::get_option_for_solver_tag(SOLVERTAG());
00177         viennacl::ocl::kernel & k = ctx.get_kernel(KernelClass::program_name(), "triangular_substitute_inplace");
00178 
00179         k.global_work_size(0, k.local_work_size());
00180         viennacl::ocl::enqueue(k(viennacl::traits::opencl_handle(mat),
00181                                  cl_uint(viennacl::traits::start1(mat)),         cl_uint(viennacl::traits::start2(mat)),
00182                                  cl_uint(viennacl::traits::stride1(mat)),        cl_uint(viennacl::traits::stride2(mat)),
00183                                  cl_uint(viennacl::traits::size1(mat)),          cl_uint(viennacl::traits::size2(mat)),
00184                                  cl_uint(viennacl::traits::internal_size1(mat)), cl_uint(viennacl::traits::internal_size2(mat)),
00185                                  viennacl::traits::opencl_handle(vec),
00186                                  cl_uint(viennacl::traits::start(vec)),
00187                                  cl_uint(viennacl::traits::stride(vec)),
00188                                  cl_uint(viennacl::traits::size(vec)),
00189                                  options
00190                                 )
00191                               );
00192       }
00193 
00199       template <typename NumericT, typename F, typename SOLVERTAG>
00200       void inplace_solve(const matrix_expression< const matrix_base<NumericT, F>, const matrix_base<NumericT, F>, op_trans> & proxy,
00201                          vector_base<NumericT> & vec,
00202                          SOLVERTAG)
00203       {
00204         viennacl::ocl::context & ctx = const_cast<viennacl::ocl::context &>(viennacl::traits::opencl_handle(vec).context());
00205 
00206         typedef viennacl::linalg::opencl::kernels::matrix<NumericT, F>  KernelClass;
00207         KernelClass::init(ctx);
00208 
00209         cl_uint options = detail::get_option_for_solver_tag(SOLVERTAG()) | 0x02;  //add transpose-flag
00210         viennacl::ocl::kernel & k = ctx.get_kernel(KernelClass::program_name(), "triangular_substitute_inplace");
00211 
00212         k.global_work_size(0, k.local_work_size());
00213         viennacl::ocl::enqueue(k(viennacl::traits::opencl_handle(proxy.lhs()),
00214                                  cl_uint(viennacl::traits::start1(proxy.lhs())),         cl_uint(viennacl::traits::start2(proxy.lhs())),
00215                                  cl_uint(viennacl::traits::stride1(proxy.lhs())),        cl_uint(viennacl::traits::stride2(proxy.lhs())),
00216                                  cl_uint(viennacl::traits::size1(proxy.lhs())),          cl_uint(viennacl::traits::size2(proxy.lhs())),
00217                                  cl_uint(viennacl::traits::internal_size1(proxy.lhs())), cl_uint(viennacl::traits::internal_size2(proxy.lhs())),
00218                                  viennacl::traits::opencl_handle(vec),
00219                                  cl_uint(viennacl::traits::start(vec)),
00220                                  cl_uint(viennacl::traits::stride(vec)),
00221                                  cl_uint(viennacl::traits::size(vec)),
00222                                  options
00223                                 )
00224                               );
00225       }
00226 
00227 
00228     }
00229   }
00230 }
00231 
00232 #endif