IT++ Logo
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Groups Pages
mat.cpp
Go to the documentation of this file.
1 
29 #include <itpp/base/mat.h>
30 
31 #ifndef _MSC_VER
32 # include <itpp/config.h>
33 #else
34 # include <itpp/config_msvc.h>
35 #endif
36 
37 #if defined (HAVE_BLAS)
38 # include <itpp/base/blas.h>
39 #endif
40 
42 
43 namespace itpp
44 {
45 
46 template<>
47 cmat cmat::hermitian_transpose() const
48 {
49  cmat temp(no_cols, no_rows);
50  for (int i = 0; i < no_rows; i++)
51  for (int j = 0; j < no_cols; j++)
52  temp(j, i) = std::conj(operator()(i,j));
53 
54  return temp;
55 }
56 
57 
58 // -------- Multiplication operator -------------
59 
60 #if defined(HAVE_BLAS)
61 
62 template<>
63 mat& mat::operator*=(const mat &m)
64 {
65  it_assert_debug(no_cols == m.no_rows, "mat::operator*=(): Wrong sizes");
66  mat r(no_rows, m.no_cols); // unnecessary memory??
67  double alpha = 1.0;
68  double beta = 0.0;
69  char trans = 'n';
70  blas::dgemm_(&trans, &trans, &no_rows, &m.no_cols, &no_cols, &alpha, data,
71  &no_rows, m.data, &m.no_rows, &beta, r.data, &r.no_rows);
72  operator=(r); // time consuming
73  return *this;
74 }
75 
76 template<>
77 cmat& cmat::operator*=(const cmat &m)
78 {
79  it_assert_debug(no_cols == m.no_rows, "cmat::operator*=(): Wrong sizes");
80  cmat r(no_rows, m.no_cols); // unnecessary memory??
81  std::complex<double> alpha = std::complex<double>(1.0);
82  std::complex<double> beta = std::complex<double>(0.0);
83  char trans = 'n';
84  blas::zgemm_(&trans, &trans, &no_rows, &m.no_cols, &no_cols, &alpha, data,
85  &no_rows, m.data, &m.no_rows, &beta, r.data, &r.no_rows);
86  operator=(r); // time consuming
87  return *this;
88 }
89 #else
90 template<>
91 mat& mat::operator*=(const mat &m)
92 {
93  it_assert_debug(no_cols == m.no_rows, "Mat<>::operator*=(): Wrong sizes");
94  mat r(no_rows, m.no_cols);
95  int r_pos = 0, pos = 0, m_pos = 0;
96 
97  for (int i = 0; i < r.no_cols; i++) {
98  for (int j = 0; j < r.no_rows; j++) {
99  double tmp = 0.0;
100  pos = 0;
101  for (int k = 0; k < no_cols; k++) {
102  tmp += data[pos+j] * m.data[m_pos+k];
103  pos += no_rows;
104  }
105  r.data[r_pos+j] = tmp;
106  }
107  r_pos += r.no_rows;
108  m_pos += m.no_rows;
109  }
110  operator=(r); // time consuming
111  return *this;
112 }
113 
114 template<>
115 cmat& cmat::operator*=(const cmat &m)
116 {
117  it_assert_debug(no_cols == m.no_rows, "Mat<>::operator*=(): Wrong sizes");
118  cmat r(no_rows, m.no_cols);
119  int r_pos = 0, pos = 0, m_pos = 0;
120 
121  for (int i = 0; i < r.no_cols; i++) {
122  for (int j = 0; j < r.no_rows; j++) {
123  std::complex<double> tmp(0.0);
124  pos = 0;
125  for (int k = 0; k < no_cols; k++) {
126  tmp += data[pos+j] * m.data[m_pos+k];
127  pos += no_rows;
128  }
129  r.data[r_pos+j] = tmp;
130  }
131  r_pos += r.no_rows;
132  m_pos += m.no_rows;
133  }
134  operator=(r); // time consuming
135  return *this;
136 }
137 #endif // HAVE_BLAS
138 
139 
140 #if defined(HAVE_BLAS)
141 template<>
142 mat operator*(const mat &m1, const mat &m2)
143 {
144  it_assert_debug(m1.no_cols == m2.no_rows, "mat::operator*(): Wrong sizes");
145  mat r(m1.no_rows, m2.no_cols);
146  double alpha = 1.0;
147  double beta = 0.0;
148  char trans = 'n';
149  blas::dgemm_(&trans, &trans, &m1.no_rows, &m2.no_cols, &m1.no_cols, &alpha,
150  m1.data, &m1.no_rows, m2.data, &m2.no_rows, &beta, r.data,
151  &r.no_rows);
152  return r;
153 }
154 
155 template<>
156 cmat operator*(const cmat &m1, const cmat &m2)
157 {
158  it_assert_debug(m1.no_cols == m2.no_rows, "cmat::operator*(): Wrong sizes");
159  cmat r(m1.no_rows, m2.no_cols);
160  std::complex<double> alpha = std::complex<double>(1.0);
161  std::complex<double> beta = std::complex<double>(0.0);
162  char trans = 'n';
163  blas::zgemm_(&trans, &trans, &m1.no_rows, &m2.no_cols, &m1.no_cols, &alpha,
164  m1.data, &m1.no_rows, m2.data, &m2.no_rows, &beta, r.data,
165  &r.no_rows);
166  return r;
167 }
168 #else
169 template<>
170 mat operator*(const mat &m1, const mat &m2)
171 {
172  it_assert_debug(m1.no_cols == m2.no_rows,
173  "Mat<>::operator*(): Wrong sizes");
174  mat r(m1.no_rows, m2.no_cols);
175  double *tr = r.data;
176  double *t1;
177  double *t2 = m2.data;
178  for (int i = 0; i < r.no_cols; i++) {
179  for (int j = 0; j < r.no_rows; j++) {
180  double tmp = 0.0;
181  t1 = m1.data + j;
182  for (int k = m1.no_cols; k > 0; k--) {
183  tmp += *(t1) * *(t2++);
184  t1 += m1.no_rows;
185  }
186  *(tr++) = tmp;
187  t2 -= m2.no_rows;
188  }
189  t2 += m2.no_rows;
190  }
191  return r;
192 }
193 
194 template<>
195 cmat operator*(const cmat &m1, const cmat &m2)
196 {
197  it_assert_debug(m1.no_cols == m2.no_rows,
198  "Mat<>::operator*(): Wrong sizes");
199  cmat r(m1.no_rows, m2.no_cols);
200  std::complex<double> *tr = r.data;
201  std::complex<double> *t1;
202  std::complex<double> *t2 = m2.data;
203  for (int i = 0; i < r.no_cols; i++) {
204  for (int j = 0; j < r.no_rows; j++) {
205  std::complex<double> tmp(0.0);
206  t1 = m1.data + j;
207  for (int k = m1.no_cols; k > 0; k--) {
208  tmp += *(t1) * *(t2++);
209  t1 += m1.no_rows;
210  }
211  *(tr++) = tmp;
212  t2 -= m2.no_rows;
213  }
214  t2 += m2.no_rows;
215  }
216  return r;
217 }
218 #endif // HAVE_BLAS
219 
220 
221 #if defined(HAVE_BLAS)
222 template<>
223 vec operator*(const mat &m, const vec &v)
224 {
225  it_assert_debug(m.no_cols == v.size(), "mat::operator*(): Wrong sizes");
226  vec r(m.no_rows);
227  double alpha = 1.0;
228  double beta = 0.0;
229  char trans = 'n';
230  int incr = 1;
231  blas::dgemv_(&trans, &m.no_rows, &m.no_cols, &alpha, m.data, &m.no_rows,
232  v._data(), &incr, &beta, r._data(), &incr);
233  return r;
234 }
235 
236 template<>
237 cvec operator*(const cmat &m, const cvec &v)
238 {
239  it_assert_debug(m.no_cols == v.size(), "cmat::operator*(): Wrong sizes");
240  cvec r(m.no_rows);
241  std::complex<double> alpha = std::complex<double>(1.0);
242  std::complex<double> beta = std::complex<double>(0.0);
243  char trans = 'n';
244  int incr = 1;
245  blas::zgemv_(&trans, &m.no_rows, &m.no_cols, &alpha, m.data, &m.no_rows,
246  v._data(), &incr, &beta, r._data(), &incr);
247  return r;
248 }
249 #else
250 template<>
251 vec operator*(const mat &m, const vec &v)
252 {
253  it_assert_debug(m.no_cols == v.size(),
254  "Mat<>::operator*(): Wrong sizes");
255  vec r(m.no_rows);
256  for (int i = 0; i < m.no_rows; i++) {
257  r(i) = 0.0;
258  int m_pos = 0;
259  for (int k = 0; k < m.no_cols; k++) {
260  r(i) += m.data[m_pos+i] * v(k);
261  m_pos += m.no_rows;
262  }
263  }
264  return r;
265 }
266 
267 template<>
268 cvec operator*(const cmat &m, const cvec &v)
269 {
270  it_assert_debug(m.no_cols == v.size(),
271  "Mat<>::operator*(): Wrong sizes");
272  cvec r(m.no_rows);
273  for (int i = 0; i < m.no_rows; i++) {
274  r(i) = std::complex<double>(0.0);
275  int m_pos = 0;
276  for (int k = 0; k < m.no_cols; k++) {
277  r(i) += m.data[m_pos+i] * v(k);
278  m_pos += m.no_rows;
279  }
280  }
281  return r;
282 }
283 #endif // HAVE_BLAS
284 
285 
286 //---------------------------------------------------------------------
287 // Instantiations
288 //---------------------------------------------------------------------
289 
290 // class instantiations
291 
292 template class Mat<double>;
293 template class Mat<std::complex<double> >;
294 template class Mat<int>;
295 template class Mat<short int>;
296 template class Mat<bin>;
297 
298 // addition operators
299 
300 template mat operator+(const mat &m1, const mat &m2);
301 template cmat operator+(const cmat &m1, const cmat &m2);
302 template imat operator+(const imat &m1, const imat &m2);
303 template smat operator+(const smat &m1, const smat &m2);
304 template bmat operator+(const bmat &m1, const bmat &m2);
305 
306 template mat operator+(const mat &m, double t);
307 template cmat operator+(const cmat &m, std::complex<double> t);
308 template imat operator+(const imat &m, int t);
309 template smat operator+(const smat &m, short t);
310 template bmat operator+(const bmat &m, bin t);
311 
312 template mat operator+(double t, const mat &m);
313 template cmat operator+(std::complex<double> t, const cmat &m);
314 template imat operator+(int t, const imat &m);
315 template smat operator+(short t, const smat &m);
316 template bmat operator+(bin t, const bmat &m);
317 
318 // subraction operators
319 
320 template mat operator-(const mat &m1, const mat &m2);
321 template cmat operator-(const cmat &m1, const cmat &m2);
322 template imat operator-(const imat &m1, const imat &m2);
323 template smat operator-(const smat &m1, const smat &m2);
324 template bmat operator-(const bmat &m1, const bmat &m2);
325 
326 template mat operator-(const mat &m, double t);
327 template cmat operator-(const cmat &m, std::complex<double> t);
328 template imat operator-(const imat &m, int t);
329 template smat operator-(const smat &m, short t);
330 template bmat operator-(const bmat &m, bin t);
331 
332 template mat operator-(double t, const mat &m);
333 template cmat operator-(std::complex<double> t, const cmat &m);
334 template imat operator-(int t, const imat &m);
335 template smat operator-(short t, const smat &m);
336 template bmat operator-(bin t, const bmat &m);
337 
338 // unary minus
339 
340 template mat operator-(const mat &m);
341 template cmat operator-(const cmat &m);
342 template imat operator-(const imat &m);
343 template smat operator-(const smat &m);
344 template bmat operator-(const bmat &m);
345 
346 // multiplication operators
347 
348 template imat operator*(const imat &m1, const imat &m2);
349 template smat operator*(const smat &m1, const smat &m2);
350 template bmat operator*(const bmat &m1, const bmat &m2);
351 
352 template ivec operator*(const imat &m, const ivec &v);
353 template svec operator*(const smat &m, const svec &v);
354 template bvec operator*(const bmat &m, const bvec &v);
355 
356 template mat operator*(const mat &m, double t);
357 template cmat operator*(const cmat &m, std::complex<double> t);
358 template imat operator*(const imat &m, int t);
359 template smat operator*(const smat &m, short t);
360 template bmat operator*(const bmat &m, bin t);
361 
362 template mat operator*(double t, const mat &m);
363 template cmat operator*(std::complex<double> t, const cmat &m);
364 template imat operator*(int t, const imat &m);
365 template smat operator*(short t, const smat &m);
366 template bmat operator*(bin t, const bmat &m);
367 
368 // elementwise multiplication
369 
370 template mat elem_mult(const mat &m1, const mat &m2);
371 template cmat elem_mult(const cmat &m1, const cmat &m2);
372 template imat elem_mult(const imat &m1, const imat &m2);
373 template smat elem_mult(const smat &m1, const smat &m2);
374 template bmat elem_mult(const bmat &m1, const bmat &m2);
375 
376 template void elem_mult_out(const mat &m1, const mat &m2, mat &out);
377 template void elem_mult_out(const cmat &m1, const cmat &m2, cmat &out);
378 template void elem_mult_out(const imat &m1, const imat &m2, imat &out);
379 template void elem_mult_out(const smat &m1, const smat &m2, smat &out);
380 template void elem_mult_out(const bmat &m1, const bmat &m2, bmat &out);
381 
382 template void elem_mult_out(const mat &m1, const mat &m2,
383  const mat &m3, mat &out);
384 template void elem_mult_out(const cmat &m1, const cmat &m2,
385  const cmat &m3, cmat &out);
386 template void elem_mult_out(const imat &m1, const imat &m2,
387  const imat &m3, imat &out);
388 template void elem_mult_out(const smat &m1, const smat &m2,
389  const smat &m3, smat &out);
390 template void elem_mult_out(const bmat &m1, const bmat &m2,
391  const bmat &m3, bmat &out);
392 
393 template void elem_mult_out(const mat &m1, const mat &m2, const mat &m3,
394  const mat &m4, mat &out);
395 template void elem_mult_out(const cmat &m1, const cmat &m2,
396  const cmat &m3, const cmat &m4, cmat &out);
397 template void elem_mult_out(const imat &m1, const imat &m2,
398  const imat &m3, const imat &m4, imat &out);
399 template void elem_mult_out(const smat &m1, const smat &m2,
400  const smat &m3, const smat &m4, smat &out);
401 template void elem_mult_out(const bmat &m1, const bmat &m2,
402  const bmat &m3, const bmat &m4, bmat &out);
403 
404 template void elem_mult_inplace(const mat &m1, mat &m2);
405 template void elem_mult_inplace(const cmat &m1, cmat &m2);
406 template void elem_mult_inplace(const imat &m1, imat &m2);
407 template void elem_mult_inplace(const smat &m1, smat &m2);
408 template void elem_mult_inplace(const bmat &m1, bmat &m2);
409 
410 template double elem_mult_sum(const mat &m1, const mat &m2);
411 template std::complex<double> elem_mult_sum(const cmat &m1, const cmat &m2);
412 template int elem_mult_sum(const imat &m1, const imat &m2);
413 template short elem_mult_sum(const smat &m1, const smat &m2);
414 template bin elem_mult_sum(const bmat &m1, const bmat &m2);
415 
416 // division operator
417 
418 template mat operator/(double t, const mat &m);
419 template cmat operator/(std::complex<double> t, const cmat &m);
420 template imat operator/(int t, const imat &m);
421 template smat operator/(short t, const smat &m);
422 template bmat operator/(bin t, const bmat &m);
423 
424 template mat operator/(const mat &m, double t);
425 template cmat operator/(const cmat &m, std::complex<double> t);
426 template imat operator/(const imat &m, int t);
427 template smat operator/(const smat &m, short t);
428 template bmat operator/(const bmat &m, bin t);
429 
430 // elementwise division
431 
432 template mat elem_div(const mat &m1, const mat &m2);
433 template cmat elem_div(const cmat &m1, const cmat &m2);
434 template imat elem_div(const imat &m1, const imat &m2);
435 template smat elem_div(const smat &m1, const smat &m2);
436 template bmat elem_div(const bmat &m1, const bmat &m2);
437 
438 template void elem_div_out(const mat &m1, const mat &m2, mat &out);
439 template void elem_div_out(const cmat &m1, const cmat &m2, cmat &out);
440 template void elem_div_out(const imat &m1, const imat &m2, imat &out);
441 template void elem_div_out(const smat &m1, const smat &m2, smat &out);
442 template void elem_div_out(const bmat &m1, const bmat &m2, bmat &out);
443 
444 template double elem_div_sum(const mat &m1, const mat &m2);
445 template std::complex<double> elem_div_sum(const cmat &m1,
446  const cmat &m2);
447 template int elem_div_sum(const imat &m1, const imat &m2);
448 template short elem_div_sum(const smat &m1, const smat &m2);
449 template bin elem_div_sum(const bmat &m1, const bmat &m2);
450 
451 // concatenation
452 
453 template mat concat_horizontal(const mat &m1, const mat &m2);
454 template cmat concat_horizontal(const cmat &m1, const cmat &m2);
455 template imat concat_horizontal(const imat &m1, const imat &m2);
456 template smat concat_horizontal(const smat &m1, const smat &m2);
457 template bmat concat_horizontal(const bmat &m1, const bmat &m2);
458 
459 template mat concat_vertical(const mat &m1, const mat &m2);
460 template cmat concat_vertical(const cmat &m1, const cmat &m2);
461 template imat concat_vertical(const imat &m1, const imat &m2);
462 template smat concat_vertical(const smat &m1, const smat &m2);
463 template bmat concat_vertical(const bmat &m1, const bmat &m2);
464 
465 // I/O streams
466 
467 template std::ostream &operator<<(std::ostream &os, const mat &m);
468 template std::ostream &operator<<(std::ostream &os, const cmat &m);
469 template std::ostream &operator<<(std::ostream &os, const imat &m);
470 template std::ostream &operator<<(std::ostream &os, const smat &m);
471 template std::ostream &operator<<(std::ostream &os, const bmat &m);
472 
473 template std::istream &operator>>(std::istream &is, mat &m);
474 template std::istream &operator>>(std::istream &is, cmat &m);
475 template std::istream &operator>>(std::istream &is, imat &m);
476 template std::istream &operator>>(std::istream &is, smat &m);
477 template std::istream &operator>>(std::istream &is, bmat &m);
478 
479 } // namespace itpp
480 
SourceForge Logo

Generated on Fri Mar 21 2014 17:14:12 for IT++ by Doxygen 1.8.1.2