ViennaCL - The Vienna Computing Library  1.5.2
viennacl/tools/matrix_size_deducer.hpp
Go to the documentation of this file.
00001 #ifndef VIENNACL_TOOLS_MATRIX_SIZE_DEDUCER_HPP_
00002 #define VIENNACL_TOOLS_MATRIX_SIZE_DEDUCER_HPP_
00003 
00004 /* =========================================================================
00005    Copyright (c) 2010-2014, Institute for Microelectronics,
00006                             Institute for Analysis and Scientific Computing,
00007                             TU Wien.
00008    Portions of this software are copyright by UChicago Argonne, LLC.
00009 
00010                             -----------------
00011                   ViennaCL - The Vienna Computing Library
00012                             -----------------
00013 
00014    Project Head:    Karl Rupp                   rupp@iue.tuwien.ac.at
00015 
00016    (A list of authors and contributors can be found in the PDF manual)
00017 
00018    License:         MIT (X11), see file LICENSE in the base directory
00019 ============================================================================= */
00020 
00025 #include <string>
00026 #include <fstream>
00027 #include <sstream>
00028 #include <cmath>
00029 #include <vector>
00030 #include <map>
00031 
00032 #include "viennacl/forwards.h"
00033 #include "viennacl/tools/adapter.hpp"
00034 
00035 namespace viennacl
00036 {
00037   namespace tools
00038   {
00039 
00046     template <typename LHS, typename RHS, typename OP>
00047     struct MATRIX_SIZE_DEDUCER
00048     {
00049       //Standard case: size1 from lhs, size2 from rhs (fits most cases)
00050       static vcl_size_t size1(LHS & lhs, RHS & /*rhs*/) { return lhs.size1(); }
00051       static vcl_size_t size2(LHS & /*lhs*/, RHS & rhs) { return rhs.size2(); }
00052     };
00053 
00055     //special case: outer vector product:
00056     template <typename ScalarType>
00057     struct MATRIX_SIZE_DEDUCER<const viennacl::vector_base<ScalarType>,
00058                                const viennacl::vector_base<ScalarType>,
00059                                viennacl::op_prod>
00060     {
00061       static vcl_size_t size1(viennacl::vector_base<ScalarType> const & lhs,
00062                                viennacl::vector_base<ScalarType> const & /*rhs*/) { return lhs.size(); }
00063 
00064       static vcl_size_t size2(viennacl::vector_base<ScalarType> const & /*lhs*/,
00065                                viennacl::vector_base<ScalarType> const & rhs) { return rhs.size(); }
00066     };
00067 
00068 
00069     //special case: multiplication with a scalar
00070     template <typename LHS, typename RHS, typename OP, typename ScalarType>
00071     struct MATRIX_SIZE_DEDUCER<const viennacl::matrix_expression<const LHS, const RHS, OP>,
00072                                const ScalarType,
00073                                viennacl::op_mult>
00074     {
00075       static vcl_size_t size1(viennacl::matrix_expression<const LHS, const RHS, OP> const & lhs,
00076                                ScalarType const & /*rhs*/) { return MATRIX_SIZE_DEDUCER<const LHS, const RHS, OP>::size1(lhs.lhs(), lhs.rhs()); }
00077 
00078       static vcl_size_t size2(viennacl::matrix_expression<const LHS, const RHS, OP> const & lhs,
00079                                ScalarType const & /*rhs*/) { return MATRIX_SIZE_DEDUCER<const LHS, const RHS, OP>::size2(lhs.lhs(), lhs.rhs()); }
00080     };
00081 
00082     //special case: multiplication with a scalar
00083     template <typename T, typename F, typename ScalarType>
00084     struct MATRIX_SIZE_DEDUCER<const viennacl::matrix_base<T, F>,
00085                                const ScalarType,
00086                                viennacl::op_mult>
00087     {
00088       static vcl_size_t size1(viennacl::matrix_base<T, F> const & lhs,
00089                                ScalarType const & /*rhs*/) { return lhs.size1(); }
00090 
00091       static vcl_size_t size2(viennacl::matrix_base<T, F> const & lhs,
00092                                ScalarType const & /*rhs*/) { return lhs.size2(); }
00093     };
00094 
00095 
00096     //special case: division with a scalar
00097     template <typename LHS, typename RHS, typename OP, typename ScalarType>
00098     struct MATRIX_SIZE_DEDUCER<const viennacl::matrix_expression<const LHS, const RHS, OP>,
00099                                const ScalarType,
00100                                viennacl::op_div>
00101     {
00102       static vcl_size_t size1(viennacl::matrix_expression<const LHS, const RHS, OP> const & lhs,
00103                                ScalarType const & /*rhs*/) { return MATRIX_SIZE_DEDUCER<const LHS, const RHS, OP>::size1(lhs.lhs(), lhs.rhs()); }
00104 
00105       static vcl_size_t size2(viennacl::matrix_expression<const LHS, const RHS, OP> const & lhs,
00106                                ScalarType const & /*rhs*/) { return MATRIX_SIZE_DEDUCER<const LHS, const RHS, OP>::size2(lhs.lhs(), lhs.rhs()); }
00107     };
00108 
00109     //special case: division with a scalar
00110     template <typename T, typename F, typename ScalarType>
00111     struct MATRIX_SIZE_DEDUCER<const viennacl::matrix_base<T, F>,
00112                                const ScalarType,
00113                                viennacl::op_div>
00114     {
00115       static vcl_size_t size1(viennacl::matrix_base<T, F> const & lhs,
00116                                ScalarType const & /*rhs*/) { return lhs.size1(); }
00117 
00118       static vcl_size_t size2(viennacl::matrix_base<T, F> const & lhs,
00119                                ScalarType const & /*rhs*/) { return lhs.size2(); }
00120     };
00121 
00122     //special case: diagonal from vector
00123     template <typename T>
00124     struct MATRIX_SIZE_DEDUCER<const viennacl::vector_base<T>,
00125                                const int,
00126                                viennacl::op_vector_diag>
00127     {
00128       static vcl_size_t size1(viennacl::vector_base<T> const & lhs,
00129                                const int k) { return lhs.size() + static_cast<vcl_size_t>(std::fabs(double(k))); }
00130 
00131       static vcl_size_t size2(viennacl::vector_base<T> const & lhs,
00132                                const int k) { return lhs.size() + static_cast<vcl_size_t>(std::fabs(double(k))); }
00133     };
00134 
00135 
00136 
00137 
00138 
00139 
00140 
00141 
00142     //special case: transposed matrix-vector product: Return the number of rows of the matrix
00143     template <typename MatrixType>
00144     struct MATRIX_SIZE_DEDUCER<MatrixType,
00145                                MatrixType,
00146                                viennacl::op_trans>
00147     {
00148       static vcl_size_t size1(const MatrixType & lhs,
00149                                const MatrixType & /*rhs*/) { return lhs.size2(); }
00150       static vcl_size_t size2(const MatrixType & lhs,
00151                                const MatrixType & /*rhs*/) { return lhs.size1(); }
00152     };
00153 
00154     // A^T * B
00155     template <typename ScalarType, typename T1, typename F2>
00156     struct MATRIX_SIZE_DEDUCER<const viennacl::matrix_expression<T1,
00157                                                                  T1, op_trans>,
00158                                const viennacl::matrix_base<ScalarType, F2>,
00159                                viennacl::op_mat_mat_prod>
00160     {
00161       static vcl_size_t size1(viennacl::matrix_expression<T1,
00162                                                            T1,
00163                                                            op_trans> const & lhs,
00164                                viennacl::matrix_base<ScalarType, F2> const & /*rhs*/) { return lhs.lhs().size2(); }
00165       static vcl_size_t size2(viennacl::matrix_expression<T1,
00166                                                            T1,
00167                                                            op_trans> const & /*lhs*/,
00168                                viennacl::matrix_base<ScalarType, F2> const & rhs) { return rhs.size2(); }
00169     };
00170 
00171 
00172     // A * B^T
00173 
00174     template <typename ScalarType, typename F1, typename T2>
00175     struct MATRIX_SIZE_DEDUCER<const viennacl::matrix_base<ScalarType, F1>,
00176                                const viennacl::matrix_expression<T2,
00177                                                                  T2, op_trans>,
00178                                viennacl::op_mat_mat_prod>
00179     {
00180       static vcl_size_t size1(viennacl::matrix_base<ScalarType, F1> const & lhs,
00181                                viennacl::matrix_expression<T2,
00182                                                            T2,
00183                                                            op_trans> const & /*rhs*/) { return lhs.size1(); }
00184       static vcl_size_t size2(viennacl::matrix_base<ScalarType, F1> const & /*lhs*/,
00185                                viennacl::matrix_expression<T2,
00186                                                            T2,
00187                                                            op_trans> const & rhs) { return rhs.lhs().size1(); }
00188     };
00189 
00190 
00191 
00192 
00193     // A^T * B^T
00194 
00195     template <typename T1, typename T2>
00196     struct MATRIX_SIZE_DEDUCER<const viennacl::matrix_expression<T1,
00197                                                                  T1, op_trans>,
00198                                const viennacl::matrix_expression<T2,
00199                                                                  T2, op_trans>,
00200                                viennacl::op_mat_mat_prod>
00201     {
00202       typedef viennacl::matrix_expression<T1, T1, op_trans>   LHSType;
00203       typedef viennacl::matrix_expression<T2, T2, op_trans>   RHSType;
00204 
00205       static vcl_size_t size1(LHSType const & lhs,
00206                                RHSType const & /*rhs*/) { return lhs.lhs().size2(); }
00207       static vcl_size_t size2(LHSType const & /*lhs*/,
00208                                RHSType const & rhs) { return rhs.lhs().size1(); }
00209     };
00211   }
00212 }
00213 
00214 #endif
00215