ViennaCL - The Vienna Computing Library  1.5.2
viennacl/scheduler/execute.hpp
Go to the documentation of this file.
00001 #ifndef VIENNACL_SCHEDULER_EXECUTE_HPP
00002 #define VIENNACL_SCHEDULER_EXECUTE_HPP
00003 
00004 /* =========================================================================
00005    Copyright (c) 2010-2014, Institute for Microelectronics,
00006                             Institute for Analysis and Scientific Computing,
00007                             TU Wien.
00008    Portions of this software are copyright by UChicago Argonne, LLC.
00009 
00010                             -----------------
00011                   ViennaCL - The Vienna Computing Library
00012                             -----------------
00013 
00014    Project Head:    Karl Rupp                   rupp@iue.tuwien.ac.at
00015 
00016    (A list of authors and contributors can be found in the PDF manual)
00017 
00018    License:         MIT (X11), see file LICENSE in the base directory
00019 ============================================================================= */
00020 
00021 
00026 #include "viennacl/forwards.h"
00027 #include "viennacl/scheduler/forwards.h"
00028 
00029 #include "viennacl/scheduler/execute_scalar_assign.hpp"
00030 #include "viennacl/scheduler/execute_axbx.hpp"
00031 #include "viennacl/scheduler/execute_elementwise.hpp"
00032 #include "viennacl/scheduler/execute_matrix_prod.hpp"
00033 
00034 namespace viennacl
00035 {
00036   namespace scheduler
00037   {
00038     namespace detail
00039     {
00041       void execute_composite(statement const & s, statement_node const & root_node)
00042       {
00043         statement::container_type const & expr = s.array();
00044 
00045         statement_node const & leaf = expr[root_node.rhs.node_index];
00046 
00047         if (leaf.op.type  == OPERATION_BINARY_ADD_TYPE || leaf.op.type  == OPERATION_BINARY_SUB_TYPE) // x = (y) +- (z)  where y and z are either data objects or expressions
00048         {
00049           execute_axbx(s, root_node);
00050         }
00051         else if (leaf.op.type == OPERATION_BINARY_MULT_TYPE || leaf.op.type == OPERATION_BINARY_DIV_TYPE) // x = (y) * / alpha;
00052         {
00053           bool scalar_is_temporary = (leaf.rhs.type_family != SCALAR_TYPE_FAMILY);
00054 
00055           statement_node scalar_temp_node;
00056           if (scalar_is_temporary)
00057           {
00058             lhs_rhs_element temp;
00059             temp.type_family  = SCALAR_TYPE_FAMILY;
00060             temp.subtype      = DEVICE_SCALAR_TYPE;
00061             temp.numeric_type = root_node.lhs.numeric_type;
00062             detail::new_element(scalar_temp_node.lhs, temp);
00063 
00064             scalar_temp_node.op.type_family = OPERATION_BINARY_TYPE_FAMILY;
00065             scalar_temp_node.op.type        = OPERATION_BINARY_ASSIGN_TYPE;
00066 
00067             scalar_temp_node.rhs.type_family  = COMPOSITE_OPERATION_FAMILY;
00068             scalar_temp_node.rhs.subtype      = INVALID_SUBTYPE;
00069             scalar_temp_node.rhs.numeric_type = INVALID_NUMERIC_TYPE;
00070             scalar_temp_node.rhs.node_index   = leaf.rhs.node_index;
00071 
00072             // work on subexpression:
00073             // TODO: Catch exception, free temporary, then rethrow
00074             execute_composite(s, scalar_temp_node);
00075           }
00076 
00077           if (leaf.lhs.type_family == COMPOSITE_OPERATION_FAMILY)  //(y) is an expression, so introduce a temporary z = (y):
00078           {
00079             statement_node new_root_y;
00080 
00081             new_root_y.lhs.type_family  = root_node.lhs.type_family;
00082             new_root_y.lhs.subtype      = root_node.lhs.subtype;
00083             new_root_y.lhs.numeric_type = root_node.lhs.numeric_type;
00084             detail::new_element(new_root_y.lhs, root_node.lhs);
00085 
00086             new_root_y.op.type_family = OPERATION_BINARY_TYPE_FAMILY;
00087             new_root_y.op.type        = OPERATION_BINARY_ASSIGN_TYPE;
00088 
00089             new_root_y.rhs.type_family  = COMPOSITE_OPERATION_FAMILY;
00090             new_root_y.rhs.subtype      = INVALID_SUBTYPE;
00091             new_root_y.rhs.numeric_type = INVALID_NUMERIC_TYPE;
00092             new_root_y.rhs.node_index   = leaf.lhs.node_index;
00093 
00094             // work on subexpression:
00095             // TODO: Catch exception, free temporary, then rethrow
00096             execute_composite(s, new_root_y);
00097 
00098             // now compute x = z * / alpha:
00099             lhs_rhs_element u = root_node.lhs;
00100             lhs_rhs_element v = new_root_y.lhs;
00101             lhs_rhs_element alpha = scalar_is_temporary ? scalar_temp_node.lhs : leaf.rhs;
00102 
00103             bool is_division = (leaf.op.type  == OPERATION_BINARY_DIV_TYPE);
00104             switch (root_node.op.type)
00105             {
00106               case OPERATION_BINARY_ASSIGN_TYPE:
00107                 detail::ax(u,
00108                            v, alpha, 1, is_division, false);
00109                 break;
00110               case OPERATION_BINARY_INPLACE_ADD_TYPE:
00111                 detail::axbx(u,
00112                              u,   1.0, 1, false,       false,
00113                              v, alpha, 1, is_division, false);
00114                 break;
00115               case OPERATION_BINARY_INPLACE_SUB_TYPE:
00116                 detail::axbx(u,
00117                              u,   1.0, 1, false,       false,
00118                              v, alpha, 1, is_division, true);
00119                 break;
00120               default:
00121                 throw statement_not_supported_exception("Unsupported binary operator for vector operation in root note (should be =, +=, or -=)");
00122             }
00123 
00124             detail::delete_element(new_root_y.lhs);
00125           }
00126           else if (leaf.lhs.type_family != COMPOSITE_OPERATION_FAMILY)
00127           {
00128             lhs_rhs_element u = root_node.lhs;
00129             lhs_rhs_element v = leaf.lhs;
00130             lhs_rhs_element alpha = scalar_is_temporary ? scalar_temp_node.lhs : leaf.rhs;
00131 
00132             bool is_division = (leaf.op.type  == OPERATION_BINARY_DIV_TYPE);
00133             switch (root_node.op.type)
00134             {
00135               case OPERATION_BINARY_ASSIGN_TYPE:
00136                 detail::ax(u,
00137                            v, alpha, 1, is_division, false);
00138                 break;
00139               case OPERATION_BINARY_INPLACE_ADD_TYPE:
00140                 detail::axbx(u,
00141                              u,   1.0, 1, false,       false,
00142                              v, alpha, 1, is_division, false);
00143                 break;
00144               case OPERATION_BINARY_INPLACE_SUB_TYPE:
00145                 detail::axbx(u,
00146                              u,   1.0, 1, false,       false,
00147                              v, alpha, 1, is_division, true);
00148                 break;
00149               default:
00150                 throw statement_not_supported_exception("Unsupported binary operator for vector operation in root note (should be =, +=, or -=)");
00151             }
00152           }
00153           else
00154             throw statement_not_supported_exception("Unsupported binary operator for OPERATION_BINARY_MULT_TYPE || OPERATION_BINARY_DIV_TYPE on leaf node.");
00155 
00156           // clean up
00157           if (scalar_is_temporary)
00158             detail::delete_element(scalar_temp_node.lhs);
00159         }
00160         else if (   leaf.op.type == OPERATION_BINARY_INNER_PROD_TYPE
00161                  || leaf.op.type == OPERATION_UNARY_NORM_1_TYPE
00162                  || leaf.op.type == OPERATION_UNARY_NORM_2_TYPE
00163                  || leaf.op.type == OPERATION_UNARY_NORM_INF_TYPE)
00164         {
00165           execute_scalar_assign_composite(s, root_node);
00166         }
00167         else if (   (leaf.op.type_family == OPERATION_UNARY_TYPE_FAMILY && leaf.op.type != OPERATION_UNARY_TRANS_TYPE)
00168                  || leaf.op.type == OPERATION_BINARY_ELEMENT_PROD_TYPE
00169                  || leaf.op.type == OPERATION_BINARY_ELEMENT_DIV_TYPE) // element-wise operations
00170         {
00171           execute_element_composite(s, root_node);
00172         }
00173         else if (   leaf.op.type == OPERATION_BINARY_MAT_VEC_PROD_TYPE
00174                  || leaf.op.type == OPERATION_BINARY_MAT_MAT_PROD_TYPE)
00175         {
00176           execute_matrix_prod(s, root_node);
00177         }
00178         else if (   leaf.op.type == OPERATION_UNARY_TRANS_TYPE)
00179         {
00180           assign_trans(root_node.lhs, leaf.lhs);
00181         }
00182         else
00183           throw statement_not_supported_exception("Unsupported binary operator");
00184       }
00185 
00186 
00188       inline void execute_single(statement const &, statement_node const & root_node)
00189       {
00190         lhs_rhs_element u = root_node.lhs;
00191         lhs_rhs_element v = root_node.rhs;
00192         switch (root_node.op.type)
00193         {
00194           case OPERATION_BINARY_ASSIGN_TYPE:
00195             detail::ax(u,
00196                        v, 1.0, 1, false, false);
00197             break;
00198           case OPERATION_BINARY_INPLACE_ADD_TYPE:
00199             detail::axbx(u,
00200                          u, 1.0, 1, false, false,
00201                          v, 1.0, 1, false, false);
00202             break;
00203           case OPERATION_BINARY_INPLACE_SUB_TYPE:
00204             detail::axbx(u,
00205                          u, 1.0, 1, false, false,
00206                          v, 1.0, 1, false, true);
00207             break;
00208           default:
00209             throw statement_not_supported_exception("Unsupported binary operator for operation in root note (should be =, +=, or -=)");
00210         }
00211 
00212       }
00213 
00214 
00215       inline void execute_impl(statement const & s, statement_node const & root_node)
00216       {
00217         if (   root_node.lhs.type_family != SCALAR_TYPE_FAMILY
00218             && root_node.lhs.type_family != VECTOR_TYPE_FAMILY
00219             && root_node.lhs.type_family != MATRIX_TYPE_FAMILY)
00220           throw statement_not_supported_exception("Unsupported lvalue encountered in head node.");
00221 
00222         switch (root_node.rhs.type_family)
00223         {
00224           case COMPOSITE_OPERATION_FAMILY:
00225             execute_composite(s, root_node);
00226             break;
00227           case SCALAR_TYPE_FAMILY:
00228           case VECTOR_TYPE_FAMILY:
00229           case MATRIX_TYPE_FAMILY:
00230             execute_single(s, root_node);
00231             break;
00232           default:
00233             throw statement_not_supported_exception("Invalid rvalue encountered in vector assignment");
00234         }
00235 
00236       }
00237     }
00238 
00239     inline void execute(statement const & s)
00240     {
00241       // simply start execution from the root node:
00242       detail::execute_impl(s, s.array()[s.root()]);
00243     }
00244 
00245 
00246   }
00247 
00248 } //namespace viennacl
00249 
00250 #endif
00251