fork of https://github.com/sourcegraph/zoekt
0

Configure Feed

Select the types of activity you want to include in your feed.

at main 136 kB View raw
1/* -*- c -*- */ 2 3/* 4 ***************************************************************************** 5 ** INCLUDES ** 6 ***************************************************************************** 7 */ 8#define PY_SSIZE_T_CLEAN 9#include <Python.h> 10 11#define NPY_NO_DEPRECATED_API NPY_API_VERSION 12#include "numpy/arrayobject.h" 13#include "numpy/ufuncobject.h" 14#include "numpy/npy_math.h" 15 16#include "npy_pycompat.h" 17 18#include "npy_config.h" 19 20#include "npy_cblas.h" 21 22#include <cstddef> 23#include <cstdio> 24#include <cassert> 25#include <cmath> 26#include <type_traits> 27#include <utility> 28 29 30static const char* umath_linalg_version_string = "0.1.5"; 31 32/* 33 **************************************************************************** 34 * Debugging support * 35 **************************************************************************** 36 */ 37#define TRACE_TXT(...) do { fprintf (stderr, __VA_ARGS__); } while (0) 38#define STACK_TRACE do {} while (0) 39#define TRACE\ 40 do { \ 41 fprintf (stderr, \ 42 "%s:%d:%s\n", \ 43 __FILE__, \ 44 __LINE__, \ 45 __FUNCTION__); \ 46 STACK_TRACE; \ 47 } while (0) 48 49#if 0 50#if defined HAVE_EXECINFO_H 51#include <execinfo.h> 52#elif defined HAVE_LIBUNWIND_H 53#include <libunwind.h> 54#endif 55void 56dbg_stack_trace() 57{ 58 void *trace[32]; 59 size_t size; 60 61 size = backtrace(trace, sizeof(trace)/sizeof(trace[0])); 62 backtrace_symbols_fd(trace, size, 1); 63} 64 65#undef STACK_TRACE 66#define STACK_TRACE do { dbg_stack_trace(); } while (0) 67#endif 68 69/* 70 ***************************************************************************** 71 * BLAS/LAPACK calling macros * 72 ***************************************************************************** 73 */ 74 75#define FNAME(x) BLAS_FUNC(x) 76 77typedef CBLAS_INT fortran_int; 78 79typedef struct { float r, i; } f2c_complex; 80typedef struct { double r, i; } f2c_doublecomplex; 81/* typedef long int (*L_fp)(); */ 82 83typedef float fortran_real; 84typedef double fortran_doublereal; 85typedef f2c_complex fortran_complex; 86typedef f2c_doublecomplex fortran_doublecomplex; 87 88extern "C" fortran_int 89FNAME(sgeev)(char *jobvl, char *jobvr, fortran_int *n, 90 float a[], fortran_int *lda, float wr[], float wi[], 91 float vl[], fortran_int *ldvl, float vr[], fortran_int *ldvr, 92 float work[], fortran_int lwork[], 93 fortran_int *info); 94extern "C" fortran_int 95FNAME(dgeev)(char *jobvl, char *jobvr, fortran_int *n, 96 double a[], fortran_int *lda, double wr[], double wi[], 97 double vl[], fortran_int *ldvl, double vr[], fortran_int *ldvr, 98 double work[], fortran_int lwork[], 99 fortran_int *info); 100extern "C" fortran_int 101FNAME(cgeev)(char *jobvl, char *jobvr, fortran_int *n, 102 f2c_complex a[], fortran_int *lda, 103 f2c_complex w[], 104 f2c_complex vl[], fortran_int *ldvl, 105 f2c_complex vr[], fortran_int *ldvr, 106 f2c_complex work[], fortran_int *lwork, 107 float rwork[], 108 fortran_int *info); 109extern "C" fortran_int 110FNAME(zgeev)(char *jobvl, char *jobvr, fortran_int *n, 111 f2c_doublecomplex a[], fortran_int *lda, 112 f2c_doublecomplex w[], 113 f2c_doublecomplex vl[], fortran_int *ldvl, 114 f2c_doublecomplex vr[], fortran_int *ldvr, 115 f2c_doublecomplex work[], fortran_int *lwork, 116 double rwork[], 117 fortran_int *info); 118 119extern "C" fortran_int 120FNAME(ssyevd)(char *jobz, char *uplo, fortran_int *n, 121 float a[], fortran_int *lda, float w[], float work[], 122 fortran_int *lwork, fortran_int iwork[], fortran_int *liwork, 123 fortran_int *info); 124extern "C" fortran_int 125FNAME(dsyevd)(char *jobz, char *uplo, fortran_int *n, 126 double a[], fortran_int *lda, double w[], double work[], 127 fortran_int *lwork, fortran_int iwork[], fortran_int *liwork, 128 fortran_int *info); 129extern "C" fortran_int 130FNAME(cheevd)(char *jobz, char *uplo, fortran_int *n, 131 f2c_complex a[], fortran_int *lda, 132 float w[], f2c_complex work[], 133 fortran_int *lwork, float rwork[], fortran_int *lrwork, fortran_int iwork[], 134 fortran_int *liwork, 135 fortran_int *info); 136extern "C" fortran_int 137FNAME(zheevd)(char *jobz, char *uplo, fortran_int *n, 138 f2c_doublecomplex a[], fortran_int *lda, 139 double w[], f2c_doublecomplex work[], 140 fortran_int *lwork, double rwork[], fortran_int *lrwork, fortran_int iwork[], 141 fortran_int *liwork, 142 fortran_int *info); 143 144extern "C" fortran_int 145FNAME(sgelsd)(fortran_int *m, fortran_int *n, fortran_int *nrhs, 146 float a[], fortran_int *lda, float b[], fortran_int *ldb, 147 float s[], float *rcond, fortran_int *rank, 148 float work[], fortran_int *lwork, fortran_int iwork[], 149 fortran_int *info); 150extern "C" fortran_int 151FNAME(dgelsd)(fortran_int *m, fortran_int *n, fortran_int *nrhs, 152 double a[], fortran_int *lda, double b[], fortran_int *ldb, 153 double s[], double *rcond, fortran_int *rank, 154 double work[], fortran_int *lwork, fortran_int iwork[], 155 fortran_int *info); 156extern "C" fortran_int 157FNAME(cgelsd)(fortran_int *m, fortran_int *n, fortran_int *nrhs, 158 f2c_complex a[], fortran_int *lda, 159 f2c_complex b[], fortran_int *ldb, 160 float s[], float *rcond, fortran_int *rank, 161 f2c_complex work[], fortran_int *lwork, 162 float rwork[], fortran_int iwork[], 163 fortran_int *info); 164extern "C" fortran_int 165FNAME(zgelsd)(fortran_int *m, fortran_int *n, fortran_int *nrhs, 166 f2c_doublecomplex a[], fortran_int *lda, 167 f2c_doublecomplex b[], fortran_int *ldb, 168 double s[], double *rcond, fortran_int *rank, 169 f2c_doublecomplex work[], fortran_int *lwork, 170 double rwork[], fortran_int iwork[], 171 fortran_int *info); 172 173extern "C" fortran_int 174FNAME(dgeqrf)(fortran_int *m, fortran_int *n, double a[], fortran_int *lda, 175 double tau[], double work[], 176 fortran_int *lwork, fortran_int *info); 177extern "C" fortran_int 178FNAME(zgeqrf)(fortran_int *m, fortran_int *n, f2c_doublecomplex a[], fortran_int *lda, 179 f2c_doublecomplex tau[], f2c_doublecomplex work[], 180 fortran_int *lwork, fortran_int *info); 181 182extern "C" fortran_int 183FNAME(dorgqr)(fortran_int *m, fortran_int *n, fortran_int *k, double a[], fortran_int *lda, 184 double tau[], double work[], 185 fortran_int *lwork, fortran_int *info); 186extern "C" fortran_int 187FNAME(zungqr)(fortran_int *m, fortran_int *n, fortran_int *k, f2c_doublecomplex a[], 188 fortran_int *lda, f2c_doublecomplex tau[], 189 f2c_doublecomplex work[], fortran_int *lwork, fortran_int *info); 190 191extern "C" fortran_int 192FNAME(sgesv)(fortran_int *n, fortran_int *nrhs, 193 float a[], fortran_int *lda, 194 fortran_int ipiv[], 195 float b[], fortran_int *ldb, 196 fortran_int *info); 197extern "C" fortran_int 198FNAME(dgesv)(fortran_int *n, fortran_int *nrhs, 199 double a[], fortran_int *lda, 200 fortran_int ipiv[], 201 double b[], fortran_int *ldb, 202 fortran_int *info); 203extern "C" fortran_int 204FNAME(cgesv)(fortran_int *n, fortran_int *nrhs, 205 f2c_complex a[], fortran_int *lda, 206 fortran_int ipiv[], 207 f2c_complex b[], fortran_int *ldb, 208 fortran_int *info); 209extern "C" fortran_int 210FNAME(zgesv)(fortran_int *n, fortran_int *nrhs, 211 f2c_doublecomplex a[], fortran_int *lda, 212 fortran_int ipiv[], 213 f2c_doublecomplex b[], fortran_int *ldb, 214 fortran_int *info); 215 216extern "C" fortran_int 217FNAME(sgetrf)(fortran_int *m, fortran_int *n, 218 float a[], fortran_int *lda, 219 fortran_int ipiv[], 220 fortran_int *info); 221extern "C" fortran_int 222FNAME(dgetrf)(fortran_int *m, fortran_int *n, 223 double a[], fortran_int *lda, 224 fortran_int ipiv[], 225 fortran_int *info); 226extern "C" fortran_int 227FNAME(cgetrf)(fortran_int *m, fortran_int *n, 228 f2c_complex a[], fortran_int *lda, 229 fortran_int ipiv[], 230 fortran_int *info); 231extern "C" fortran_int 232FNAME(zgetrf)(fortran_int *m, fortran_int *n, 233 f2c_doublecomplex a[], fortran_int *lda, 234 fortran_int ipiv[], 235 fortran_int *info); 236 237extern "C" fortran_int 238FNAME(spotrf)(char *uplo, fortran_int *n, 239 float a[], fortran_int *lda, 240 fortran_int *info); 241extern "C" fortran_int 242FNAME(dpotrf)(char *uplo, fortran_int *n, 243 double a[], fortran_int *lda, 244 fortran_int *info); 245extern "C" fortran_int 246FNAME(cpotrf)(char *uplo, fortran_int *n, 247 f2c_complex a[], fortran_int *lda, 248 fortran_int *info); 249extern "C" fortran_int 250FNAME(zpotrf)(char *uplo, fortran_int *n, 251 f2c_doublecomplex a[], fortran_int *lda, 252 fortran_int *info); 253 254extern "C" fortran_int 255FNAME(sgesdd)(char *jobz, fortran_int *m, fortran_int *n, 256 float a[], fortran_int *lda, float s[], float u[], 257 fortran_int *ldu, float vt[], fortran_int *ldvt, float work[], 258 fortran_int *lwork, fortran_int iwork[], fortran_int *info); 259extern "C" fortran_int 260FNAME(dgesdd)(char *jobz, fortran_int *m, fortran_int *n, 261 double a[], fortran_int *lda, double s[], double u[], 262 fortran_int *ldu, double vt[], fortran_int *ldvt, double work[], 263 fortran_int *lwork, fortran_int iwork[], fortran_int *info); 264extern "C" fortran_int 265FNAME(cgesdd)(char *jobz, fortran_int *m, fortran_int *n, 266 f2c_complex a[], fortran_int *lda, 267 float s[], f2c_complex u[], fortran_int *ldu, 268 f2c_complex vt[], fortran_int *ldvt, 269 f2c_complex work[], fortran_int *lwork, 270 float rwork[], fortran_int iwork[], fortran_int *info); 271extern "C" fortran_int 272FNAME(zgesdd)(char *jobz, fortran_int *m, fortran_int *n, 273 f2c_doublecomplex a[], fortran_int *lda, 274 double s[], f2c_doublecomplex u[], fortran_int *ldu, 275 f2c_doublecomplex vt[], fortran_int *ldvt, 276 f2c_doublecomplex work[], fortran_int *lwork, 277 double rwork[], fortran_int iwork[], fortran_int *info); 278 279extern "C" fortran_int 280FNAME(spotrs)(char *uplo, fortran_int *n, fortran_int *nrhs, 281 float a[], fortran_int *lda, 282 float b[], fortran_int *ldb, 283 fortran_int *info); 284extern "C" fortran_int 285FNAME(dpotrs)(char *uplo, fortran_int *n, fortran_int *nrhs, 286 double a[], fortran_int *lda, 287 double b[], fortran_int *ldb, 288 fortran_int *info); 289extern "C" fortran_int 290FNAME(cpotrs)(char *uplo, fortran_int *n, fortran_int *nrhs, 291 f2c_complex a[], fortran_int *lda, 292 f2c_complex b[], fortran_int *ldb, 293 fortran_int *info); 294extern "C" fortran_int 295FNAME(zpotrs)(char *uplo, fortran_int *n, fortran_int *nrhs, 296 f2c_doublecomplex a[], fortran_int *lda, 297 f2c_doublecomplex b[], fortran_int *ldb, 298 fortran_int *info); 299 300extern "C" fortran_int 301FNAME(spotri)(char *uplo, fortran_int *n, 302 float a[], fortran_int *lda, 303 fortran_int *info); 304extern "C" fortran_int 305FNAME(dpotri)(char *uplo, fortran_int *n, 306 double a[], fortran_int *lda, 307 fortran_int *info); 308extern "C" fortran_int 309FNAME(cpotri)(char *uplo, fortran_int *n, 310 f2c_complex a[], fortran_int *lda, 311 fortran_int *info); 312extern "C" fortran_int 313FNAME(zpotri)(char *uplo, fortran_int *n, 314 f2c_doublecomplex a[], fortran_int *lda, 315 fortran_int *info); 316 317extern "C" fortran_int 318FNAME(scopy)(fortran_int *n, 319 float *sx, fortran_int *incx, 320 float *sy, fortran_int *incy); 321extern "C" fortran_int 322FNAME(dcopy)(fortran_int *n, 323 double *sx, fortran_int *incx, 324 double *sy, fortran_int *incy); 325extern "C" fortran_int 326FNAME(ccopy)(fortran_int *n, 327 f2c_complex *sx, fortran_int *incx, 328 f2c_complex *sy, fortran_int *incy); 329extern "C" fortran_int 330FNAME(zcopy)(fortran_int *n, 331 f2c_doublecomplex *sx, fortran_int *incx, 332 f2c_doublecomplex *sy, fortran_int *incy); 333 334extern "C" float 335FNAME(sdot)(fortran_int *n, 336 float *sx, fortran_int *incx, 337 float *sy, fortran_int *incy); 338extern "C" double 339FNAME(ddot)(fortran_int *n, 340 double *sx, fortran_int *incx, 341 double *sy, fortran_int *incy); 342extern "C" void 343FNAME(cdotu)(f2c_complex *ret, fortran_int *n, 344 f2c_complex *sx, fortran_int *incx, 345 f2c_complex *sy, fortran_int *incy); 346extern "C" void 347FNAME(zdotu)(f2c_doublecomplex *ret, fortran_int *n, 348 f2c_doublecomplex *sx, fortran_int *incx, 349 f2c_doublecomplex *sy, fortran_int *incy); 350extern "C" void 351FNAME(cdotc)(f2c_complex *ret, fortran_int *n, 352 f2c_complex *sx, fortran_int *incx, 353 f2c_complex *sy, fortran_int *incy); 354extern "C" void 355FNAME(zdotc)(f2c_doublecomplex *ret, fortran_int *n, 356 f2c_doublecomplex *sx, fortran_int *incx, 357 f2c_doublecomplex *sy, fortran_int *incy); 358 359extern "C" fortran_int 360FNAME(sgemm)(char *transa, char *transb, 361 fortran_int *m, fortran_int *n, fortran_int *k, 362 float *alpha, 363 float *a, fortran_int *lda, 364 float *b, fortran_int *ldb, 365 float *beta, 366 float *c, fortran_int *ldc); 367extern "C" fortran_int 368FNAME(dgemm)(char *transa, char *transb, 369 fortran_int *m, fortran_int *n, fortran_int *k, 370 double *alpha, 371 double *a, fortran_int *lda, 372 double *b, fortran_int *ldb, 373 double *beta, 374 double *c, fortran_int *ldc); 375extern "C" fortran_int 376FNAME(cgemm)(char *transa, char *transb, 377 fortran_int *m, fortran_int *n, fortran_int *k, 378 f2c_complex *alpha, 379 f2c_complex *a, fortran_int *lda, 380 f2c_complex *b, fortran_int *ldb, 381 f2c_complex *beta, 382 f2c_complex *c, fortran_int *ldc); 383extern "C" fortran_int 384FNAME(zgemm)(char *transa, char *transb, 385 fortran_int *m, fortran_int *n, fortran_int *k, 386 f2c_doublecomplex *alpha, 387 f2c_doublecomplex *a, fortran_int *lda, 388 f2c_doublecomplex *b, fortran_int *ldb, 389 f2c_doublecomplex *beta, 390 f2c_doublecomplex *c, fortran_int *ldc); 391 392 393#define LAPACK_T(FUNC) \ 394 TRACE_TXT("Calling LAPACK ( " # FUNC " )\n"); \ 395 FNAME(FUNC) 396 397#define BLAS(FUNC) \ 398 FNAME(FUNC) 399 400#define LAPACK(FUNC) \ 401 FNAME(FUNC) 402 403 404/* 405 ***************************************************************************** 406 ** Some handy functions ** 407 ***************************************************************************** 408 */ 409 410static inline int 411get_fp_invalid_and_clear(void) 412{ 413 int status; 414 status = npy_clear_floatstatus_barrier((char*)&status); 415 return !!(status & NPY_FPE_INVALID); 416} 417 418static inline void 419set_fp_invalid_or_clear(int error_occurred) 420{ 421 if (error_occurred) { 422 npy_set_floatstatus_invalid(); 423 } 424 else { 425 npy_clear_floatstatus_barrier((char*)&error_occurred); 426 } 427} 428 429/* 430 ***************************************************************************** 431 ** Some handy constants ** 432 ***************************************************************************** 433 */ 434 435#define UMATH_LINALG_MODULE_NAME "_umath_linalg" 436 437template<typename T> 438struct numeric_limits; 439 440template<> 441struct numeric_limits<float> { 442static constexpr float one = 1.0f; 443static constexpr float zero = 0.0f; 444static constexpr float minus_one = -1.0f; 445static const float ninf; 446static const float nan; 447}; 448constexpr float numeric_limits<float>::one; 449constexpr float numeric_limits<float>::zero; 450constexpr float numeric_limits<float>::minus_one; 451const float numeric_limits<float>::ninf = -NPY_INFINITYF; 452const float numeric_limits<float>::nan = NPY_NANF; 453 454template<> 455struct numeric_limits<double> { 456static constexpr double one = 1.0; 457static constexpr double zero = 0.0; 458static constexpr double minus_one = -1.0; 459static const double ninf; 460static const double nan; 461}; 462constexpr double numeric_limits<double>::one; 463constexpr double numeric_limits<double>::zero; 464constexpr double numeric_limits<double>::minus_one; 465const double numeric_limits<double>::ninf = -NPY_INFINITY; 466const double numeric_limits<double>::nan = NPY_NAN; 467 468#if defined(_MSC_VER) && !defined(__INTEL_COMPILER) 469template<> 470struct numeric_limits<npy_cfloat> { 471static constexpr npy_cfloat one = {1.0f, 0.0f}; 472static constexpr npy_cfloat zero = {0.0f, 0.0f}; 473static constexpr npy_cfloat minus_one = {-1.0f, 0.0f}; 474static const npy_cfloat ninf; 475static const npy_cfloat nan; 476}; 477constexpr npy_cfloat numeric_limits<npy_cfloat>::one; 478constexpr npy_cfloat numeric_limits<npy_cfloat>::zero; 479constexpr npy_cfloat numeric_limits<npy_cfloat>::minus_one; 480const npy_cfloat numeric_limits<npy_cfloat>::ninf = {-NPY_INFINITYF, 0.0f}; 481const npy_cfloat numeric_limits<npy_cfloat>::nan = {NPY_NANF, NPY_NANF}; 482#else 483template<> 484struct numeric_limits<npy_cfloat> { 485static constexpr npy_cfloat one = 1.0f; 486static constexpr npy_cfloat zero = 0.0f; 487static constexpr npy_cfloat minus_one = -1.0f; 488static const npy_cfloat ninf; 489static const npy_cfloat nan; 490}; 491constexpr npy_cfloat numeric_limits<npy_cfloat>::one; 492constexpr npy_cfloat numeric_limits<npy_cfloat>::zero; 493constexpr npy_cfloat numeric_limits<npy_cfloat>::minus_one; 494const npy_cfloat numeric_limits<npy_cfloat>::ninf = -NPY_INFINITYF; 495const npy_cfloat numeric_limits<npy_cfloat>::nan = NPY_NANF; 496#endif 497 498template<> 499struct numeric_limits<f2c_complex> { 500static constexpr f2c_complex one = {1.0f, 0.0f}; 501static constexpr f2c_complex zero = {0.0f, 0.0f}; 502static constexpr f2c_complex minus_one = {-1.0f, 0.0f}; 503static const f2c_complex ninf; 504static const f2c_complex nan; 505}; 506constexpr f2c_complex numeric_limits<f2c_complex>::one; 507constexpr f2c_complex numeric_limits<f2c_complex>::zero; 508constexpr f2c_complex numeric_limits<f2c_complex>::minus_one; 509const f2c_complex numeric_limits<f2c_complex>::ninf = {-NPY_INFINITYF, 0.0f}; 510const f2c_complex numeric_limits<f2c_complex>::nan = {NPY_NANF, NPY_NANF}; 511 512#if defined(_MSC_VER) && !defined(__INTEL_COMPILER) 513template<> 514struct numeric_limits<npy_cdouble> { 515static constexpr npy_cdouble one = {1.0, 0.0}; 516static constexpr npy_cdouble zero = {0.0, 0.0}; 517static constexpr npy_cdouble minus_one = {-1.0, 0.0}; 518static const npy_cdouble ninf; 519static const npy_cdouble nan; 520}; 521constexpr npy_cdouble numeric_limits<npy_cdouble>::one; 522constexpr npy_cdouble numeric_limits<npy_cdouble>::zero; 523constexpr npy_cdouble numeric_limits<npy_cdouble>::minus_one; 524const npy_cdouble numeric_limits<npy_cdouble>::ninf = {-NPY_INFINITY, 0.0}; 525const npy_cdouble numeric_limits<npy_cdouble>::nan = {NPY_NAN, NPY_NAN}; 526#else 527template<> 528struct numeric_limits<npy_cdouble> { 529static constexpr npy_cdouble one = 1.0; 530static constexpr npy_cdouble zero = 0.0; 531static constexpr npy_cdouble minus_one = -1.0; 532static const npy_cdouble ninf; 533static const npy_cdouble nan; 534}; 535constexpr npy_cdouble numeric_limits<npy_cdouble>::one; 536constexpr npy_cdouble numeric_limits<npy_cdouble>::zero; 537constexpr npy_cdouble numeric_limits<npy_cdouble>::minus_one; 538const npy_cdouble numeric_limits<npy_cdouble>::ninf = -NPY_INFINITY; 539const npy_cdouble numeric_limits<npy_cdouble>::nan = NPY_NAN; 540#endif 541 542template<> 543struct numeric_limits<f2c_doublecomplex> { 544static constexpr f2c_doublecomplex one = {1.0, 0.0}; 545static constexpr f2c_doublecomplex zero = {0.0, 0.0}; 546static constexpr f2c_doublecomplex minus_one = {-1.0, 0.0}; 547static const f2c_doublecomplex ninf; 548static const f2c_doublecomplex nan; 549}; 550constexpr f2c_doublecomplex numeric_limits<f2c_doublecomplex>::one; 551constexpr f2c_doublecomplex numeric_limits<f2c_doublecomplex>::zero; 552constexpr f2c_doublecomplex numeric_limits<f2c_doublecomplex>::minus_one; 553const f2c_doublecomplex numeric_limits<f2c_doublecomplex>::ninf = {-NPY_INFINITY, 0.0}; 554const f2c_doublecomplex numeric_limits<f2c_doublecomplex>::nan = {NPY_NAN, NPY_NAN}; 555 556/* 557 ***************************************************************************** 558 ** Structs used for data rearrangement ** 559 ***************************************************************************** 560 */ 561 562 563/* 564 * this struct contains information about how to linearize a matrix in a local 565 * buffer so that it can be used by blas functions. All strides are specified 566 * in bytes and are converted to elements later in type specific functions. 567 * 568 * rows: number of rows in the matrix 569 * columns: number of columns in the matrix 570 * row_strides: the number bytes between consecutive rows. 571 * column_strides: the number of bytes between consecutive columns. 572 * output_lead_dim: BLAS/LAPACK-side leading dimension, in elements 573 */ 574typedef struct linearize_data_struct 575{ 576 npy_intp rows; 577 npy_intp columns; 578 npy_intp row_strides; 579 npy_intp column_strides; 580 npy_intp output_lead_dim; 581} LINEARIZE_DATA_t; 582 583static inline void 584init_linearize_data_ex(LINEARIZE_DATA_t *lin_data, 585 npy_intp rows, 586 npy_intp columns, 587 npy_intp row_strides, 588 npy_intp column_strides, 589 npy_intp output_lead_dim) 590{ 591 lin_data->rows = rows; 592 lin_data->columns = columns; 593 lin_data->row_strides = row_strides; 594 lin_data->column_strides = column_strides; 595 lin_data->output_lead_dim = output_lead_dim; 596} 597 598static inline void 599init_linearize_data(LINEARIZE_DATA_t *lin_data, 600 npy_intp rows, 601 npy_intp columns, 602 npy_intp row_strides, 603 npy_intp column_strides) 604{ 605 init_linearize_data_ex( 606 lin_data, rows, columns, row_strides, column_strides, columns); 607} 608 609static inline void 610dump_ufunc_object(PyUFuncObject* ufunc) 611{ 612 TRACE_TXT("\n\n%s '%s' (%d input(s), %d output(s), %d specialization(s).\n", 613 ufunc->core_enabled? "generalized ufunc" : "scalar ufunc", 614 ufunc->name, ufunc->nin, ufunc->nout, ufunc->ntypes); 615 if (ufunc->core_enabled) { 616 int arg; 617 int dim; 618 TRACE_TXT("\t%s (%d dimension(s) detected).\n", 619 ufunc->core_signature, ufunc->core_num_dim_ix); 620 621 for (arg = 0; arg < ufunc->nargs; arg++){ 622 int * arg_dim_ix = ufunc->core_dim_ixs + ufunc->core_offsets[arg]; 623 TRACE_TXT("\t\targ %d (%s) has %d dimension(s): (", 624 arg, arg < ufunc->nin? "INPUT" : "OUTPUT", 625 ufunc->core_num_dims[arg]); 626 for (dim = 0; dim < ufunc->core_num_dims[arg]; dim ++) { 627 TRACE_TXT(" %d", arg_dim_ix[dim]); 628 } 629 TRACE_TXT(" )\n"); 630 } 631 } 632} 633 634static inline void 635dump_linearize_data(const char* name, const LINEARIZE_DATA_t* params) 636{ 637 TRACE_TXT("\n\t%s rows: %zd columns: %zd"\ 638 "\n\t\trow_strides: %td column_strides: %td"\ 639 "\n", name, params->rows, params->columns, 640 params->row_strides, params->column_strides); 641} 642 643static inline void 644print(npy_float s) 645{ 646 TRACE_TXT(" %8.4f", s); 647} 648static inline void 649print(npy_double d) 650{ 651 TRACE_TXT(" %10.6f", d); 652} 653static inline void 654print(npy_cfloat c) 655{ 656 float* c_parts = (float*)&c; 657 TRACE_TXT("(%8.4f, %8.4fj)", c_parts[0], c_parts[1]); 658} 659static inline void 660print(npy_cdouble z) 661{ 662 double* z_parts = (double*)&z; 663 TRACE_TXT("(%8.4f, %8.4fj)", z_parts[0], z_parts[1]); 664} 665 666template<typename typ> 667static inline void 668dump_matrix(const char* name, 669 size_t rows, size_t columns, 670 const typ* ptr) 671{ 672 size_t i, j; 673 674 TRACE_TXT("\n%s %p (%zd, %zd)\n", name, ptr, rows, columns); 675 for (i = 0; i < rows; i++) 676 { 677 TRACE_TXT("| "); 678 for (j = 0; j < columns; j++) 679 { 680 print(ptr[j*rows + i]); 681 TRACE_TXT(", "); 682 } 683 TRACE_TXT(" |\n"); 684 } 685} 686 687 688/* 689 ***************************************************************************** 690 ** Basics ** 691 ***************************************************************************** 692 */ 693 694static inline fortran_int 695fortran_int_min(fortran_int x, fortran_int y) { 696 return x < y ? x : y; 697} 698 699static inline fortran_int 700fortran_int_max(fortran_int x, fortran_int y) { 701 return x > y ? x : y; 702} 703 704#define INIT_OUTER_LOOP_1 \ 705 npy_intp dN = *dimensions++;\ 706 npy_intp N_;\ 707 npy_intp s0 = *steps++; 708 709#define INIT_OUTER_LOOP_2 \ 710 INIT_OUTER_LOOP_1\ 711 npy_intp s1 = *steps++; 712 713#define INIT_OUTER_LOOP_3 \ 714 INIT_OUTER_LOOP_2\ 715 npy_intp s2 = *steps++; 716 717#define INIT_OUTER_LOOP_4 \ 718 INIT_OUTER_LOOP_3\ 719 npy_intp s3 = *steps++; 720 721#define INIT_OUTER_LOOP_5 \ 722 INIT_OUTER_LOOP_4\ 723 npy_intp s4 = *steps++; 724 725#define INIT_OUTER_LOOP_6 \ 726 INIT_OUTER_LOOP_5\ 727 npy_intp s5 = *steps++; 728 729#define INIT_OUTER_LOOP_7 \ 730 INIT_OUTER_LOOP_6\ 731 npy_intp s6 = *steps++; 732 733#define BEGIN_OUTER_LOOP_2 \ 734 for (N_ = 0;\ 735 N_ < dN;\ 736 N_++, args[0] += s0,\ 737 args[1] += s1) { 738 739#define BEGIN_OUTER_LOOP_3 \ 740 for (N_ = 0;\ 741 N_ < dN;\ 742 N_++, args[0] += s0,\ 743 args[1] += s1,\ 744 args[2] += s2) { 745 746#define BEGIN_OUTER_LOOP_4 \ 747 for (N_ = 0;\ 748 N_ < dN;\ 749 N_++, args[0] += s0,\ 750 args[1] += s1,\ 751 args[2] += s2,\ 752 args[3] += s3) { 753 754#define BEGIN_OUTER_LOOP_5 \ 755 for (N_ = 0;\ 756 N_ < dN;\ 757 N_++, args[0] += s0,\ 758 args[1] += s1,\ 759 args[2] += s2,\ 760 args[3] += s3,\ 761 args[4] += s4) { 762 763#define BEGIN_OUTER_LOOP_6 \ 764 for (N_ = 0;\ 765 N_ < dN;\ 766 N_++, args[0] += s0,\ 767 args[1] += s1,\ 768 args[2] += s2,\ 769 args[3] += s3,\ 770 args[4] += s4,\ 771 args[5] += s5) { 772 773#define BEGIN_OUTER_LOOP_7 \ 774 for (N_ = 0;\ 775 N_ < dN;\ 776 N_++, args[0] += s0,\ 777 args[1] += s1,\ 778 args[2] += s2,\ 779 args[3] += s3,\ 780 args[4] += s4,\ 781 args[5] += s5,\ 782 args[6] += s6) { 783 784#define END_OUTER_LOOP } 785 786static inline void 787update_pointers(npy_uint8** bases, ptrdiff_t* offsets, size_t count) 788{ 789 size_t i; 790 for (i = 0; i < count; ++i) { 791 bases[i] += offsets[i]; 792 } 793} 794 795 796/* disable -Wmaybe-uninitialized as there is some code that generate false 797 positives with this warning 798*/ 799#pragma GCC diagnostic push 800#pragma GCC diagnostic ignored "-Wmaybe-uninitialized" 801 802/* 803 ***************************************************************************** 804 ** DISPATCHER FUNCS ** 805 ***************************************************************************** 806 */ 807static fortran_int copy(fortran_int *n, 808 float *sx, fortran_int *incx, 809 float *sy, fortran_int *incy) { return FNAME(scopy)(n, sx, incx, 810 sy, incy); 811} 812static fortran_int copy(fortran_int *n, 813 double *sx, fortran_int *incx, 814 double *sy, fortran_int *incy) { return FNAME(dcopy)(n, sx, incx, 815 sy, incy); 816} 817static fortran_int copy(fortran_int *n, 818 f2c_complex *sx, fortran_int *incx, 819 f2c_complex *sy, fortran_int *incy) { return FNAME(ccopy)(n, sx, incx, 820 sy, incy); 821} 822static fortran_int copy(fortran_int *n, 823 f2c_doublecomplex *sx, fortran_int *incx, 824 f2c_doublecomplex *sy, fortran_int *incy) { return FNAME(zcopy)(n, sx, incx, 825 sy, incy); 826} 827 828static fortran_int getrf(fortran_int *m, fortran_int *n, float a[], fortran_int 829*lda, fortran_int ipiv[], fortran_int *info) { 830 return LAPACK(sgetrf)(m, n, a, lda, ipiv, info); 831} 832static fortran_int getrf(fortran_int *m, fortran_int *n, double a[], fortran_int 833*lda, fortran_int ipiv[], fortran_int *info) { 834 return LAPACK(dgetrf)(m, n, a, lda, ipiv, info); 835} 836static fortran_int getrf(fortran_int *m, fortran_int *n, f2c_complex a[], fortran_int 837*lda, fortran_int ipiv[], fortran_int *info) { 838 return LAPACK(cgetrf)(m, n, a, lda, ipiv, info); 839} 840static fortran_int getrf(fortran_int *m, fortran_int *n, f2c_doublecomplex a[], fortran_int 841*lda, fortran_int ipiv[], fortran_int *info) { 842 return LAPACK(zgetrf)(m, n, a, lda, ipiv, info); 843} 844 845/* 846 ***************************************************************************** 847 ** HELPER FUNCS ** 848 ***************************************************************************** 849 */ 850template<typename T> 851struct fortran_type { 852using type = T; 853}; 854 855template<> struct fortran_type<npy_cfloat> { using type = f2c_complex;}; 856template<> struct fortran_type<npy_cdouble> { using type = f2c_doublecomplex;}; 857template<typename T> 858using fortran_type_t = typename fortran_type<T>::type; 859 860template<typename T> 861struct basetype { 862using type = T; 863}; 864template<> struct basetype<npy_cfloat> { using type = npy_float;}; 865template<> struct basetype<npy_cdouble> { using type = npy_double;}; 866template<> struct basetype<f2c_complex> { using type = fortran_real;}; 867template<> struct basetype<f2c_doublecomplex> { using type = fortran_doublereal;}; 868template<typename T> 869using basetype_t = typename basetype<T>::type; 870 871struct scalar_trait {}; 872struct complex_trait {}; 873template<typename typ> 874using dispatch_scalar = typename std::conditional<sizeof(basetype_t<typ>) == sizeof(typ), scalar_trait, complex_trait>::type; 875 876 877 /* rearranging of 2D matrices using blas */ 878 879template<typename typ> 880static inline void * 881linearize_matrix(typ *dst, 882 typ *src, 883 const LINEARIZE_DATA_t* data) 884{ 885 using ftyp = fortran_type_t<typ>; 886 if (dst) { 887 int i, j; 888 typ* rv = dst; 889 fortran_int columns = (fortran_int)data->columns; 890 fortran_int column_strides = 891 (fortran_int)(data->column_strides/sizeof(typ)); 892 fortran_int one = 1; 893 for (i = 0; i < data->rows; i++) { 894 if (column_strides > 0) { 895 copy(&columns, 896 (ftyp*)src, &column_strides, 897 (ftyp*)dst, &one); 898 } 899 else if (column_strides < 0) { 900 copy(&columns, 901 ((ftyp*)src + (columns-1)*column_strides), 902 &column_strides, 903 (ftyp*)dst, &one); 904 } 905 else { 906 /* 907 * Zero stride has undefined behavior in some BLAS 908 * implementations (e.g. OSX Accelerate), so do it 909 * manually 910 */ 911 for (j = 0; j < columns; ++j) { 912 memcpy(dst + j, src, sizeof(typ)); 913 } 914 } 915 src += data->row_strides/sizeof(typ); 916 dst += data->output_lead_dim; 917 } 918 return rv; 919 } else { 920 return src; 921 } 922} 923 924template<typename typ> 925static inline void * 926delinearize_matrix(typ *dst, 927 typ *src, 928 const LINEARIZE_DATA_t* data) 929{ 930using ftyp = fortran_type_t<typ>; 931 932 if (src) { 933 int i; 934 typ *rv = src; 935 fortran_int columns = (fortran_int)data->columns; 936 fortran_int column_strides = 937 (fortran_int)(data->column_strides/sizeof(typ)); 938 fortran_int one = 1; 939 for (i = 0; i < data->rows; i++) { 940 if (column_strides > 0) { 941 copy(&columns, 942 (ftyp*)src, &one, 943 (ftyp*)dst, &column_strides); 944 } 945 else if (column_strides < 0) { 946 copy(&columns, 947 (ftyp*)src, &one, 948 ((ftyp*)dst + (columns-1)*column_strides), 949 &column_strides); 950 } 951 else { 952 /* 953 * Zero stride has undefined behavior in some BLAS 954 * implementations (e.g. OSX Accelerate), so do it 955 * manually 956 */ 957 if (columns > 0) { 958 memcpy(dst, 959 src + (columns-1), 960 sizeof(typ)); 961 } 962 } 963 src += data->output_lead_dim; 964 dst += data->row_strides/sizeof(typ); 965 } 966 967 return rv; 968 } else { 969 return src; 970 } 971} 972 973template<typename typ> 974static inline void 975nan_matrix(typ *dst, const LINEARIZE_DATA_t* data) 976{ 977 int i, j; 978 for (i = 0; i < data->rows; i++) { 979 typ *cp = dst; 980 ptrdiff_t cs = data->column_strides/sizeof(typ); 981 for (j = 0; j < data->columns; ++j) { 982 *cp = numeric_limits<typ>::nan; 983 cp += cs; 984 } 985 dst += data->row_strides/sizeof(typ); 986 } 987} 988 989template<typename typ> 990static inline void 991zero_matrix(typ *dst, const LINEARIZE_DATA_t* data) 992{ 993 int i, j; 994 for (i = 0; i < data->rows; i++) { 995 typ *cp = dst; 996 ptrdiff_t cs = data->column_strides/sizeof(typ); 997 for (j = 0; j < data->columns; ++j) { 998 *cp = numeric_limits<typ>::zero; 999 cp += cs; 1000 } 1001 dst += data->row_strides/sizeof(typ); 1002 } 1003} 1004 1005 /* identity square matrix generation */ 1006template<typename typ> 1007static inline void 1008identity_matrix(typ *matrix, size_t n) 1009{ 1010 size_t i; 1011 /* in IEEE floating point, zeroes are represented as bitwise 0 */ 1012 memset((void *)matrix, 0, n*n*sizeof(typ)); 1013 1014 for (i = 0; i < n; ++i) 1015 { 1016 *matrix = numeric_limits<typ>::one; 1017 matrix += n+1; 1018 } 1019} 1020 1021 /* lower/upper triangular matrix using blas (in place) */ 1022 1023template<typename typ> 1024static inline void 1025triu_matrix(typ *matrix, size_t n) 1026{ 1027 size_t i, j; 1028 matrix += n; 1029 for (i = 1; i < n; ++i) { 1030 for (j = 0; j < i; ++j) { 1031 matrix[j] = numeric_limits<typ>::zero; 1032 } 1033 matrix += n; 1034 } 1035} 1036 1037 1038/* -------------------------------------------------------------------------- */ 1039 /* Determinants */ 1040 1041static npy_float npylog(npy_float f) { return npy_logf(f);} 1042static npy_double npylog(npy_double d) { return npy_log(d);} 1043static npy_float npyexp(npy_float f) { return npy_expf(f);} 1044static npy_double npyexp(npy_double d) { return npy_exp(d);} 1045 1046template<typename typ> 1047static inline void 1048slogdet_from_factored_diagonal(typ* src, 1049 fortran_int m, 1050 typ *sign, 1051 typ *logdet) 1052{ 1053 typ acc_sign = *sign; 1054 typ acc_logdet = numeric_limits<typ>::zero; 1055 int i; 1056 for (i = 0; i < m; i++) { 1057 typ abs_element = *src; 1058 if (abs_element < numeric_limits<typ>::zero) { 1059 acc_sign = -acc_sign; 1060 abs_element = -abs_element; 1061 } 1062 1063 acc_logdet += npylog(abs_element); 1064 src += m+1; 1065 } 1066 1067 *sign = acc_sign; 1068 *logdet = acc_logdet; 1069} 1070 1071template<typename typ> 1072static inline typ 1073det_from_slogdet(typ sign, typ logdet) 1074{ 1075 typ result = sign * npyexp(logdet); 1076 return result; 1077} 1078 1079 1080npy_float npyabs(npy_cfloat z) { return npy_cabsf(z);} 1081npy_double npyabs(npy_cdouble z) { return npy_cabs(z);} 1082 1083inline float RE(npy_cfloat *c) { return npy_crealf(*c); } 1084inline double RE(npy_cdouble *c) { return npy_creal(*c); } 1085#if NPY_SIZEOF_COMPLEX_LONGDOUBLE != NPY_SIZEOF_COMPLEX_DOUBLE 1086inline longdouble_t RE(npy_clongdouble *c) { return npy_creall(*c); } 1087#endif 1088inline float IM(npy_cfloat *c) { return npy_cimagf(*c); } 1089inline double IM(npy_cdouble *c) { return npy_cimag(*c); } 1090#if NPY_SIZEOF_COMPLEX_LONGDOUBLE != NPY_SIZEOF_COMPLEX_DOUBLE 1091inline longdouble_t IM(npy_clongdouble *c) { return npy_cimagl(*c); } 1092#endif 1093inline void SETRE(npy_cfloat *c, float real) { npy_csetrealf(c, real); } 1094inline void SETRE(npy_cdouble *c, double real) { npy_csetreal(c, real); } 1095#if NPY_SIZEOF_COMPLEX_LONGDOUBLE != NPY_SIZEOF_COMPLEX_DOUBLE 1096inline void SETRE(npy_clongdouble *c, double real) { npy_csetreall(c, real); } 1097#endif 1098inline void SETIM(npy_cfloat *c, float real) { npy_csetimagf(c, real); } 1099inline void SETIM(npy_cdouble *c, double real) { npy_csetimag(c, real); } 1100#if NPY_SIZEOF_COMPLEX_LONGDOUBLE != NPY_SIZEOF_COMPLEX_DOUBLE 1101inline void SETIM(npy_clongdouble *c, double real) { npy_csetimagl(c, real); } 1102#endif 1103 1104template<typename typ> 1105static inline typ 1106mult(typ op1, typ op2) 1107{ 1108 typ rv; 1109 1110 SETRE(&rv, RE(&op1)*RE(&op2) - IM(&op1)*IM(&op2)); 1111 SETIM(&rv, RE(&op1)*IM(&op2) + IM(&op1)*RE(&op2)); 1112 1113 return rv; 1114} 1115 1116 1117template<typename typ, typename basetyp> 1118static inline void 1119slogdet_from_factored_diagonal(typ* src, 1120 fortran_int m, 1121 typ *sign, 1122 basetyp *logdet) 1123{ 1124 int i; 1125 typ sign_acc = *sign; 1126 basetyp logdet_acc = numeric_limits<basetyp>::zero; 1127 1128 for (i = 0; i < m; i++) 1129 { 1130 basetyp abs_element = npyabs(*src); 1131 typ sign_element; 1132 SETRE(&sign_element, RE(src) / abs_element); 1133 SETIM(&sign_element, IM(src) / abs_element); 1134 1135 sign_acc = mult(sign_acc, sign_element); 1136 logdet_acc += npylog(abs_element); 1137 src += m + 1; 1138 } 1139 1140 *sign = sign_acc; 1141 *logdet = logdet_acc; 1142} 1143 1144template<typename typ, typename basetyp> 1145static inline typ 1146det_from_slogdet(typ sign, basetyp logdet) 1147{ 1148 typ tmp; 1149 SETRE(&tmp, npyexp(logdet)); 1150 SETIM(&tmp, numeric_limits<basetyp>::zero); 1151 return mult(sign, tmp); 1152} 1153 1154 1155/* As in the linalg package, the determinant is computed via LU factorization 1156 * using LAPACK. 1157 * slogdet computes sign + log(determinant). 1158 * det computes sign * exp(slogdet). 1159 */ 1160template<typename typ, typename basetyp> 1161static inline void 1162slogdet_single_element(fortran_int m, 1163 typ* src, 1164 fortran_int* pivots, 1165 typ *sign, 1166 basetyp *logdet) 1167{ 1168using ftyp = fortran_type_t<typ>; 1169 fortran_int info = 0; 1170 fortran_int lda = fortran_int_max(m, 1); 1171 int i; 1172 /* note: done in place */ 1173 getrf(&m, &m, (ftyp*)src, &lda, pivots, &info); 1174 1175 if (info == 0) { 1176 int change_sign = 0; 1177 /* note: fortran uses 1 based indexing */ 1178 for (i = 0; i < m; i++) 1179 { 1180 change_sign += (pivots[i] != (i+1)); 1181 } 1182 1183 *sign = (change_sign % 2)?numeric_limits<typ>::minus_one:numeric_limits<typ>::one; 1184 slogdet_from_factored_diagonal(src, m, sign, logdet); 1185 } else { 1186 /* 1187 if getrf fails, use 0 as sign and -inf as logdet 1188 */ 1189 *sign = numeric_limits<typ>::zero; 1190 *logdet = numeric_limits<basetyp>::ninf; 1191 } 1192} 1193 1194template<typename typ, typename basetyp> 1195static void 1196slogdet(char **args, 1197 npy_intp const *dimensions, 1198 npy_intp const *steps, 1199 void *NPY_UNUSED(func)) 1200{ 1201 fortran_int m; 1202 char *tmp_buff = NULL; 1203 size_t matrix_size; 1204 size_t pivot_size; 1205 size_t safe_m; 1206 /* notes: 1207 * matrix will need to be copied always, as factorization in lapack is 1208 * made inplace 1209 * matrix will need to be in column-major order, as expected by lapack 1210 * code (fortran) 1211 * always a square matrix 1212 * need to allocate memory for both, matrix_buffer and pivot buffer 1213 */ 1214 INIT_OUTER_LOOP_3 1215 m = (fortran_int) dimensions[0]; 1216 /* avoid empty malloc (buffers likely unused) and ensure m is `size_t` */ 1217 safe_m = m != 0 ? m : 1; 1218 matrix_size = safe_m * safe_m * sizeof(typ); 1219 pivot_size = safe_m * sizeof(fortran_int); 1220 tmp_buff = (char *)malloc(matrix_size + pivot_size); 1221 1222 if (tmp_buff) { 1223 LINEARIZE_DATA_t lin_data; 1224 /* swapped steps to get matrix in FORTRAN order */ 1225 init_linearize_data(&lin_data, m, m, steps[1], steps[0]); 1226 BEGIN_OUTER_LOOP_3 1227 linearize_matrix((typ*)tmp_buff, (typ*)args[0], &lin_data); 1228 slogdet_single_element(m, 1229 (typ*)tmp_buff, 1230 (fortran_int*)(tmp_buff+matrix_size), 1231 (typ*)args[1], 1232 (basetyp*)args[2]); 1233 END_OUTER_LOOP 1234 1235 free(tmp_buff); 1236 } 1237 else { 1238 /* TODO: Requires use of new ufunc API to indicate error return */ 1239 NPY_ALLOW_C_API_DEF 1240 NPY_ALLOW_C_API; 1241 PyErr_NoMemory(); 1242 NPY_DISABLE_C_API; 1243 } 1244} 1245 1246template<typename typ, typename basetyp> 1247static void 1248det(char **args, 1249 npy_intp const *dimensions, 1250 npy_intp const *steps, 1251 void *NPY_UNUSED(func)) 1252{ 1253 fortran_int m; 1254 char *tmp_buff; 1255 size_t matrix_size; 1256 size_t pivot_size; 1257 size_t safe_m; 1258 /* notes: 1259 * matrix will need to be copied always, as factorization in lapack is 1260 * made inplace 1261 * matrix will need to be in column-major order, as expected by lapack 1262 * code (fortran) 1263 * always a square matrix 1264 * need to allocate memory for both, matrix_buffer and pivot buffer 1265 */ 1266 INIT_OUTER_LOOP_2 1267 m = (fortran_int) dimensions[0]; 1268 /* avoid empty malloc (buffers likely unused) and ensure m is `size_t` */ 1269 safe_m = m != 0 ? m : 1; 1270 matrix_size = safe_m * safe_m * sizeof(typ); 1271 pivot_size = safe_m * sizeof(fortran_int); 1272 tmp_buff = (char *)malloc(matrix_size + pivot_size); 1273 1274 if (tmp_buff) { 1275 LINEARIZE_DATA_t lin_data; 1276 typ sign; 1277 basetyp logdet; 1278 /* swapped steps to get matrix in FORTRAN order */ 1279 init_linearize_data(&lin_data, m, m, steps[1], steps[0]); 1280 1281 BEGIN_OUTER_LOOP_2 1282 linearize_matrix((typ*)tmp_buff, (typ*)args[0], &lin_data); 1283 slogdet_single_element(m, 1284 (typ*)tmp_buff, 1285 (fortran_int*)(tmp_buff + matrix_size), 1286 &sign, 1287 &logdet); 1288 *(typ *)args[1] = det_from_slogdet(sign, logdet); 1289 END_OUTER_LOOP 1290 1291 free(tmp_buff); 1292 } 1293 else { 1294 /* TODO: Requires use of new ufunc API to indicate error return */ 1295 NPY_ALLOW_C_API_DEF 1296 NPY_ALLOW_C_API; 1297 PyErr_NoMemory(); 1298 NPY_DISABLE_C_API; 1299 } 1300} 1301 1302 1303/* -------------------------------------------------------------------------- */ 1304 /* Eigh family */ 1305 1306template<typename typ> 1307struct EIGH_PARAMS_t { 1308 typ *A; /* matrix */ 1309 basetype_t<typ> *W; /* eigenvalue vector */ 1310 typ *WORK; /* main work buffer */ 1311 basetype_t<typ> *RWORK; /* secondary work buffer (for complex versions) */ 1312 fortran_int *IWORK; 1313 fortran_int N; 1314 fortran_int LWORK; 1315 fortran_int LRWORK; 1316 fortran_int LIWORK; 1317 char JOBZ; 1318 char UPLO; 1319 fortran_int LDA; 1320} ; 1321 1322static inline fortran_int 1323call_evd(EIGH_PARAMS_t<npy_float> *params) 1324{ 1325 fortran_int rv; 1326 LAPACK(ssyevd)(&params->JOBZ, &params->UPLO, &params->N, 1327 params->A, &params->LDA, params->W, 1328 params->WORK, &params->LWORK, 1329 params->IWORK, &params->LIWORK, 1330 &rv); 1331 return rv; 1332} 1333static inline fortran_int 1334call_evd(EIGH_PARAMS_t<npy_double> *params) 1335{ 1336 fortran_int rv; 1337 LAPACK(dsyevd)(&params->JOBZ, &params->UPLO, &params->N, 1338 params->A, &params->LDA, params->W, 1339 params->WORK, &params->LWORK, 1340 params->IWORK, &params->LIWORK, 1341 &rv); 1342 return rv; 1343} 1344 1345 1346/* 1347 * Initialize the parameters to use in for the lapack function _syevd 1348 * Handles buffer allocation 1349 */ 1350template<typename typ> 1351static inline int 1352init_evd(EIGH_PARAMS_t<typ>* params, char JOBZ, char UPLO, 1353 fortran_int N, scalar_trait) 1354{ 1355 npy_uint8 *mem_buff = NULL; 1356 npy_uint8 *mem_buff2 = NULL; 1357 fortran_int lwork; 1358 fortran_int liwork; 1359 npy_uint8 *a, *w, *work, *iwork; 1360 size_t safe_N = N; 1361 size_t alloc_size = safe_N * (safe_N + 1) * sizeof(typ); 1362 fortran_int lda = fortran_int_max(N, 1); 1363 1364 mem_buff = (npy_uint8 *)malloc(alloc_size); 1365 1366 if (!mem_buff) { 1367 goto error; 1368 } 1369 a = mem_buff; 1370 w = mem_buff + safe_N * safe_N * sizeof(typ); 1371 1372 params->A = (typ*)a; 1373 params->W = (typ*)w; 1374 params->RWORK = NULL; /* unused */ 1375 params->N = N; 1376 params->LRWORK = 0; /* unused */ 1377 params->JOBZ = JOBZ; 1378 params->UPLO = UPLO; 1379 params->LDA = lda; 1380 1381 /* Work size query */ 1382 { 1383 typ query_work_size; 1384 fortran_int query_iwork_size; 1385 1386 params->LWORK = -1; 1387 params->LIWORK = -1; 1388 params->WORK = &query_work_size; 1389 params->IWORK = &query_iwork_size; 1390 1391 if (call_evd(params) != 0) { 1392 goto error; 1393 } 1394 1395 lwork = (fortran_int)query_work_size; 1396 liwork = query_iwork_size; 1397 } 1398 1399 mem_buff2 = (npy_uint8 *)malloc(lwork*sizeof(typ) + liwork*sizeof(fortran_int)); 1400 if (!mem_buff2) { 1401 goto error; 1402 } 1403 1404 work = mem_buff2; 1405 iwork = mem_buff2 + lwork*sizeof(typ); 1406 1407 params->LWORK = lwork; 1408 params->WORK = (typ*)work; 1409 params->LIWORK = liwork; 1410 params->IWORK = (fortran_int*)iwork; 1411 1412 return 1; 1413 1414 error: 1415 /* something failed */ 1416 memset(params, 0, sizeof(*params)); 1417 free(mem_buff2); 1418 free(mem_buff); 1419 1420 return 0; 1421} 1422 1423 1424static inline fortran_int 1425call_evd(EIGH_PARAMS_t<npy_cfloat> *params) 1426{ 1427 fortran_int rv; 1428 LAPACK(cheevd)(&params->JOBZ, &params->UPLO, &params->N, 1429 (fortran_type_t<npy_cfloat>*)params->A, &params->LDA, params->W, 1430 (fortran_type_t<npy_cfloat>*)params->WORK, &params->LWORK, 1431 params->RWORK, &params->LRWORK, 1432 params->IWORK, &params->LIWORK, 1433 &rv); 1434 return rv; 1435} 1436 1437static inline fortran_int 1438call_evd(EIGH_PARAMS_t<npy_cdouble> *params) 1439{ 1440 fortran_int rv; 1441 LAPACK(zheevd)(&params->JOBZ, &params->UPLO, &params->N, 1442 (fortran_type_t<npy_cdouble>*)params->A, &params->LDA, params->W, 1443 (fortran_type_t<npy_cdouble>*)params->WORK, &params->LWORK, 1444 params->RWORK, &params->LRWORK, 1445 params->IWORK, &params->LIWORK, 1446 &rv); 1447 return rv; 1448} 1449 1450template<typename typ> 1451static inline int 1452init_evd(EIGH_PARAMS_t<typ> *params, 1453 char JOBZ, 1454 char UPLO, 1455 fortran_int N, complex_trait) 1456{ 1457 using basetyp = basetype_t<typ>; 1458using ftyp = fortran_type_t<typ>; 1459using fbasetyp = fortran_type_t<basetyp>; 1460 npy_uint8 *mem_buff = NULL; 1461 npy_uint8 *mem_buff2 = NULL; 1462 fortran_int lwork; 1463 fortran_int lrwork; 1464 fortran_int liwork; 1465 npy_uint8 *a, *w, *work, *rwork, *iwork; 1466 size_t safe_N = N; 1467 fortran_int lda = fortran_int_max(N, 1); 1468 1469 mem_buff = (npy_uint8 *)malloc(safe_N * safe_N * sizeof(typ) + 1470 safe_N * sizeof(basetyp)); 1471 if (!mem_buff) { 1472 goto error; 1473 } 1474 a = mem_buff; 1475 w = mem_buff + safe_N * safe_N * sizeof(typ); 1476 1477 params->A = (typ*)a; 1478 params->W = (basetyp*)w; 1479 params->N = N; 1480 params->JOBZ = JOBZ; 1481 params->UPLO = UPLO; 1482 params->LDA = lda; 1483 1484 /* Work size query */ 1485 { 1486 ftyp query_work_size; 1487 fbasetyp query_rwork_size; 1488 fortran_int query_iwork_size; 1489 1490 params->LWORK = -1; 1491 params->LRWORK = -1; 1492 params->LIWORK = -1; 1493 params->WORK = (typ*)&query_work_size; 1494 params->RWORK = (basetyp*)&query_rwork_size; 1495 params->IWORK = &query_iwork_size; 1496 1497 if (call_evd(params) != 0) { 1498 goto error; 1499 } 1500 1501 lwork = (fortran_int)*(fbasetyp*)&query_work_size; 1502 lrwork = (fortran_int)query_rwork_size; 1503 liwork = query_iwork_size; 1504 } 1505 1506 mem_buff2 = (npy_uint8 *)malloc(lwork*sizeof(typ) + 1507 lrwork*sizeof(basetyp) + 1508 liwork*sizeof(fortran_int)); 1509 if (!mem_buff2) { 1510 goto error; 1511 } 1512 1513 work = mem_buff2; 1514 rwork = work + lwork*sizeof(typ); 1515 iwork = rwork + lrwork*sizeof(basetyp); 1516 1517 params->WORK = (typ*)work; 1518 params->RWORK = (basetyp*)rwork; 1519 params->IWORK = (fortran_int*)iwork; 1520 params->LWORK = lwork; 1521 params->LRWORK = lrwork; 1522 params->LIWORK = liwork; 1523 1524 return 1; 1525 1526 /* something failed */ 1527error: 1528 memset(params, 0, sizeof(*params)); 1529 free(mem_buff2); 1530 free(mem_buff); 1531 1532 return 0; 1533} 1534 1535/* 1536 * (M, M)->(M,)(M, M) 1537 * dimensions[1] -> M 1538 * args[0] -> A[in] 1539 * args[1] -> W 1540 * args[2] -> A[out] 1541 */ 1542 1543template<typename typ> 1544static inline void 1545release_evd(EIGH_PARAMS_t<typ> *params) 1546{ 1547 /* allocated memory in A and WORK */ 1548 free(params->A); 1549 free(params->WORK); 1550 memset(params, 0, sizeof(*params)); 1551} 1552 1553 1554template<typename typ> 1555static inline void 1556eigh_wrapper(char JOBZ, 1557 char UPLO, 1558 char**args, 1559 npy_intp const *dimensions, 1560 npy_intp const *steps) 1561{ 1562 using basetyp = basetype_t<typ>; 1563 ptrdiff_t outer_steps[3]; 1564 size_t iter; 1565 size_t outer_dim = *dimensions++; 1566 size_t op_count = (JOBZ=='N')?2:3; 1567 EIGH_PARAMS_t<typ> eigh_params; 1568 int error_occurred = get_fp_invalid_and_clear(); 1569 1570 for (iter = 0; iter < op_count; ++iter) { 1571 outer_steps[iter] = (ptrdiff_t) steps[iter]; 1572 } 1573 steps += op_count; 1574 1575 if (init_evd(&eigh_params, 1576 JOBZ, 1577 UPLO, 1578 (fortran_int)dimensions[0], dispatch_scalar<typ>())) { 1579 LINEARIZE_DATA_t matrix_in_ld; 1580 LINEARIZE_DATA_t eigenvectors_out_ld; 1581 LINEARIZE_DATA_t eigenvalues_out_ld; 1582 1583 init_linearize_data(&matrix_in_ld, 1584 eigh_params.N, eigh_params.N, 1585 steps[1], steps[0]); 1586 init_linearize_data(&eigenvalues_out_ld, 1587 1, eigh_params.N, 1588 0, steps[2]); 1589 if ('V' == eigh_params.JOBZ) { 1590 init_linearize_data(&eigenvectors_out_ld, 1591 eigh_params.N, eigh_params.N, 1592 steps[4], steps[3]); 1593 } 1594 1595 for (iter = 0; iter < outer_dim; ++iter) { 1596 int not_ok; 1597 /* copy the matrix in */ 1598 linearize_matrix((typ*)eigh_params.A, (typ*)args[0], &matrix_in_ld); 1599 not_ok = call_evd(&eigh_params); 1600 if (!not_ok) { 1601 /* lapack ok, copy result out */ 1602 delinearize_matrix((basetyp*)args[1], 1603 (basetyp*)eigh_params.W, 1604 &eigenvalues_out_ld); 1605 1606 if ('V' == eigh_params.JOBZ) { 1607 delinearize_matrix((typ*)args[2], 1608 (typ*)eigh_params.A, 1609 &eigenvectors_out_ld); 1610 } 1611 } else { 1612 /* lapack fail, set result to nan */ 1613 error_occurred = 1; 1614 nan_matrix((basetyp*)args[1], &eigenvalues_out_ld); 1615 if ('V' == eigh_params.JOBZ) { 1616 nan_matrix((typ*)args[2], &eigenvectors_out_ld); 1617 } 1618 } 1619 update_pointers((npy_uint8**)args, outer_steps, op_count); 1620 } 1621 1622 release_evd(&eigh_params); 1623 } 1624 1625 set_fp_invalid_or_clear(error_occurred); 1626} 1627 1628 1629template<typename typ> 1630static void 1631eighlo(char **args, 1632 npy_intp const *dimensions, 1633 npy_intp const *steps, 1634 void *NPY_UNUSED(func)) 1635{ 1636 eigh_wrapper<typ>('V', 'L', args, dimensions, steps); 1637} 1638 1639template<typename typ> 1640static void 1641eighup(char **args, 1642 npy_intp const *dimensions, 1643 npy_intp const *steps, 1644 void* NPY_UNUSED(func)) 1645{ 1646 eigh_wrapper<typ>('V', 'U', args, dimensions, steps); 1647} 1648 1649template<typename typ> 1650static void 1651eigvalshlo(char **args, 1652 npy_intp const *dimensions, 1653 npy_intp const *steps, 1654 void* NPY_UNUSED(func)) 1655{ 1656 eigh_wrapper<typ>('N', 'L', args, dimensions, steps); 1657} 1658 1659template<typename typ> 1660static void 1661eigvalshup(char **args, 1662 npy_intp const *dimensions, 1663 npy_intp const *steps, 1664 void* NPY_UNUSED(func)) 1665{ 1666 eigh_wrapper<typ>('N', 'U', args, dimensions, steps); 1667} 1668 1669/* -------------------------------------------------------------------------- */ 1670 /* Solve family (includes inv) */ 1671 1672template<typename typ> 1673struct GESV_PARAMS_t 1674{ 1675 typ *A; /* A is (N, N) of base type */ 1676 typ *B; /* B is (N, NRHS) of base type */ 1677 fortran_int * IPIV; /* IPIV is (N) */ 1678 1679 fortran_int N; 1680 fortran_int NRHS; 1681 fortran_int LDA; 1682 fortran_int LDB; 1683}; 1684 1685static inline fortran_int 1686call_gesv(GESV_PARAMS_t<fortran_real> *params) 1687{ 1688 fortran_int rv; 1689 LAPACK(sgesv)(&params->N, &params->NRHS, 1690 params->A, &params->LDA, 1691 params->IPIV, 1692 params->B, &params->LDB, 1693 &rv); 1694 return rv; 1695} 1696 1697static inline fortran_int 1698call_gesv(GESV_PARAMS_t<fortran_doublereal> *params) 1699{ 1700 fortran_int rv; 1701 LAPACK(dgesv)(&params->N, &params->NRHS, 1702 params->A, &params->LDA, 1703 params->IPIV, 1704 params->B, &params->LDB, 1705 &rv); 1706 return rv; 1707} 1708 1709static inline fortran_int 1710call_gesv(GESV_PARAMS_t<fortran_complex> *params) 1711{ 1712 fortran_int rv; 1713 LAPACK(cgesv)(&params->N, &params->NRHS, 1714 params->A, &params->LDA, 1715 params->IPIV, 1716 params->B, &params->LDB, 1717 &rv); 1718 return rv; 1719} 1720 1721static inline fortran_int 1722call_gesv(GESV_PARAMS_t<fortran_doublecomplex> *params) 1723{ 1724 fortran_int rv; 1725 LAPACK(zgesv)(&params->N, &params->NRHS, 1726 params->A, &params->LDA, 1727 params->IPIV, 1728 params->B, &params->LDB, 1729 &rv); 1730 return rv; 1731} 1732 1733 1734/* 1735 * Initialize the parameters to use in for the lapack function _heev 1736 * Handles buffer allocation 1737 */ 1738template<typename ftyp> 1739static inline int 1740init_gesv(GESV_PARAMS_t<ftyp> *params, fortran_int N, fortran_int NRHS) 1741{ 1742 npy_uint8 *mem_buff = NULL; 1743 npy_uint8 *a, *b, *ipiv; 1744 size_t safe_N = N; 1745 size_t safe_NRHS = NRHS; 1746 fortran_int ld = fortran_int_max(N, 1); 1747 mem_buff = (npy_uint8 *)malloc(safe_N * safe_N * sizeof(ftyp) + 1748 safe_N * safe_NRHS*sizeof(ftyp) + 1749 safe_N * sizeof(fortran_int)); 1750 if (!mem_buff) { 1751 goto error; 1752 } 1753 a = mem_buff; 1754 b = a + safe_N * safe_N * sizeof(ftyp); 1755 ipiv = b + safe_N * safe_NRHS * sizeof(ftyp); 1756 1757 params->A = (ftyp*)a; 1758 params->B = (ftyp*)b; 1759 params->IPIV = (fortran_int*)ipiv; 1760 params->N = N; 1761 params->NRHS = NRHS; 1762 params->LDA = ld; 1763 params->LDB = ld; 1764 1765 return 1; 1766 error: 1767 free(mem_buff); 1768 memset(params, 0, sizeof(*params)); 1769 1770 return 0; 1771} 1772 1773template<typename ftyp> 1774static inline void 1775release_gesv(GESV_PARAMS_t<ftyp> *params) 1776{ 1777 /* memory block base is in A */ 1778 free(params->A); 1779 memset(params, 0, sizeof(*params)); 1780} 1781 1782template<typename typ> 1783static void 1784solve(char **args, npy_intp const *dimensions, npy_intp const *steps, 1785 void *NPY_UNUSED(func)) 1786{ 1787using ftyp = fortran_type_t<typ>; 1788 GESV_PARAMS_t<ftyp> params; 1789 fortran_int n, nrhs; 1790 int error_occurred = get_fp_invalid_and_clear(); 1791 INIT_OUTER_LOOP_3 1792 1793 n = (fortran_int)dimensions[0]; 1794 nrhs = (fortran_int)dimensions[1]; 1795 if (init_gesv(&params, n, nrhs)) { 1796 LINEARIZE_DATA_t a_in, b_in, r_out; 1797 1798 init_linearize_data(&a_in, n, n, steps[1], steps[0]); 1799 init_linearize_data(&b_in, nrhs, n, steps[3], steps[2]); 1800 init_linearize_data(&r_out, nrhs, n, steps[5], steps[4]); 1801 1802 BEGIN_OUTER_LOOP_3 1803 int not_ok; 1804 linearize_matrix((typ*)params.A, (typ*)args[0], &a_in); 1805 linearize_matrix((typ*)params.B, (typ*)args[1], &b_in); 1806 not_ok =call_gesv(&params); 1807 if (!not_ok) { 1808 delinearize_matrix((typ*)args[2], (typ*)params.B, &r_out); 1809 } else { 1810 error_occurred = 1; 1811 nan_matrix((typ*)args[2], &r_out); 1812 } 1813 END_OUTER_LOOP 1814 1815 release_gesv(&params); 1816 } 1817 1818 set_fp_invalid_or_clear(error_occurred); 1819} 1820 1821 1822template<typename typ> 1823static void 1824solve1(char **args, npy_intp const *dimensions, npy_intp const *steps, 1825 void *NPY_UNUSED(func)) 1826{ 1827using ftyp = fortran_type_t<typ>; 1828 GESV_PARAMS_t<ftyp> params; 1829 int error_occurred = get_fp_invalid_and_clear(); 1830 fortran_int n; 1831 INIT_OUTER_LOOP_3 1832 1833 n = (fortran_int)dimensions[0]; 1834 if (init_gesv(&params, n, 1)) { 1835 LINEARIZE_DATA_t a_in, b_in, r_out; 1836 init_linearize_data(&a_in, n, n, steps[1], steps[0]); 1837 init_linearize_data(&b_in, 1, n, 1, steps[2]); 1838 init_linearize_data(&r_out, 1, n, 1, steps[3]); 1839 1840 BEGIN_OUTER_LOOP_3 1841 int not_ok; 1842 linearize_matrix((typ*)params.A, (typ*)args[0], &a_in); 1843 linearize_matrix((typ*)params.B, (typ*)args[1], &b_in); 1844 not_ok = call_gesv(&params); 1845 if (!not_ok) { 1846 delinearize_matrix((typ*)args[2], (typ*)params.B, &r_out); 1847 } else { 1848 error_occurred = 1; 1849 nan_matrix((typ*)args[2], &r_out); 1850 } 1851 END_OUTER_LOOP 1852 1853 release_gesv(&params); 1854 } 1855 1856 set_fp_invalid_or_clear(error_occurred); 1857} 1858 1859template<typename typ> 1860static void 1861inv(char **args, npy_intp const *dimensions, npy_intp const *steps, 1862 void *NPY_UNUSED(func)) 1863{ 1864using ftyp = fortran_type_t<typ>; 1865 GESV_PARAMS_t<ftyp> params; 1866 fortran_int n; 1867 int error_occurred = get_fp_invalid_and_clear(); 1868 INIT_OUTER_LOOP_2 1869 1870 n = (fortran_int)dimensions[0]; 1871 if (init_gesv(&params, n, n)) { 1872 LINEARIZE_DATA_t a_in, r_out; 1873 init_linearize_data(&a_in, n, n, steps[1], steps[0]); 1874 init_linearize_data(&r_out, n, n, steps[3], steps[2]); 1875 1876 BEGIN_OUTER_LOOP_2 1877 int not_ok; 1878 linearize_matrix((typ*)params.A, (typ*)args[0], &a_in); 1879 identity_matrix((typ*)params.B, n); 1880 not_ok = call_gesv(&params); 1881 if (!not_ok) { 1882 delinearize_matrix((typ*)args[1], (typ*)params.B, &r_out); 1883 } else { 1884 error_occurred = 1; 1885 nan_matrix((typ*)args[1], &r_out); 1886 } 1887 END_OUTER_LOOP 1888 1889 release_gesv(&params); 1890 } 1891 1892 set_fp_invalid_or_clear(error_occurred); 1893} 1894 1895 1896/* -------------------------------------------------------------------------- */ 1897 /* Cholesky decomposition */ 1898 1899template<typename typ> 1900struct POTR_PARAMS_t 1901{ 1902 typ *A; 1903 fortran_int N; 1904 fortran_int LDA; 1905 char UPLO; 1906}; 1907 1908 1909static inline fortran_int 1910call_potrf(POTR_PARAMS_t<fortran_real> *params) 1911{ 1912 fortran_int rv; 1913 LAPACK(spotrf)(&params->UPLO, 1914 &params->N, params->A, &params->LDA, 1915 &rv); 1916 return rv; 1917} 1918 1919static inline fortran_int 1920call_potrf(POTR_PARAMS_t<fortran_doublereal> *params) 1921{ 1922 fortran_int rv; 1923 LAPACK(dpotrf)(&params->UPLO, 1924 &params->N, params->A, &params->LDA, 1925 &rv); 1926 return rv; 1927} 1928 1929static inline fortran_int 1930call_potrf(POTR_PARAMS_t<fortran_complex> *params) 1931{ 1932 fortran_int rv; 1933 LAPACK(cpotrf)(&params->UPLO, 1934 &params->N, params->A, &params->LDA, 1935 &rv); 1936 return rv; 1937} 1938 1939static inline fortran_int 1940call_potrf(POTR_PARAMS_t<fortran_doublecomplex> *params) 1941{ 1942 fortran_int rv; 1943 LAPACK(zpotrf)(&params->UPLO, 1944 &params->N, params->A, &params->LDA, 1945 &rv); 1946 return rv; 1947} 1948 1949template<typename ftyp> 1950static inline int 1951init_potrf(POTR_PARAMS_t<ftyp> *params, char UPLO, fortran_int N) 1952{ 1953 npy_uint8 *mem_buff = NULL; 1954 npy_uint8 *a; 1955 size_t safe_N = N; 1956 fortran_int lda = fortran_int_max(N, 1); 1957 1958 mem_buff = (npy_uint8 *)malloc(safe_N * safe_N * sizeof(ftyp)); 1959 if (!mem_buff) { 1960 goto error; 1961 } 1962 1963 a = mem_buff; 1964 1965 params->A = (ftyp*)a; 1966 params->N = N; 1967 params->LDA = lda; 1968 params->UPLO = UPLO; 1969 1970 return 1; 1971 error: 1972 free(mem_buff); 1973 memset(params, 0, sizeof(*params)); 1974 1975 return 0; 1976} 1977 1978template<typename ftyp> 1979static inline void 1980release_potrf(POTR_PARAMS_t<ftyp> *params) 1981{ 1982 /* memory block base in A */ 1983 free(params->A); 1984 memset(params, 0, sizeof(*params)); 1985} 1986 1987template<typename typ> 1988static void 1989cholesky(char uplo, char **args, npy_intp const *dimensions, npy_intp const *steps) 1990{ 1991 using ftyp = fortran_type_t<typ>; 1992 POTR_PARAMS_t<ftyp> params; 1993 int error_occurred = get_fp_invalid_and_clear(); 1994 fortran_int n; 1995 INIT_OUTER_LOOP_2 1996 1997 assert(uplo == 'L'); 1998 1999 n = (fortran_int)dimensions[0]; 2000 if (init_potrf(&params, uplo, n)) { 2001 LINEARIZE_DATA_t a_in, r_out; 2002 init_linearize_data(&a_in, n, n, steps[1], steps[0]); 2003 init_linearize_data(&r_out, n, n, steps[3], steps[2]); 2004 BEGIN_OUTER_LOOP_2 2005 int not_ok; 2006 linearize_matrix(params.A, (ftyp*)args[0], &a_in); 2007 not_ok = call_potrf(&params); 2008 if (!not_ok) { 2009 triu_matrix(params.A, params.N); 2010 delinearize_matrix((ftyp*)args[1], params.A, &r_out); 2011 } else { 2012 error_occurred = 1; 2013 nan_matrix((ftyp*)args[1], &r_out); 2014 } 2015 END_OUTER_LOOP 2016 release_potrf(&params); 2017 } 2018 2019 set_fp_invalid_or_clear(error_occurred); 2020} 2021 2022template<typename typ> 2023static void 2024cholesky_lo(char **args, npy_intp const *dimensions, npy_intp const *steps, 2025 void *NPY_UNUSED(func)) 2026{ 2027 cholesky<typ>('L', args, dimensions, steps); 2028} 2029 2030/* -------------------------------------------------------------------------- */ 2031 /* eig family */ 2032 2033template<typename typ> 2034struct GEEV_PARAMS_t { 2035 typ *A; 2036 basetype_t<typ> *WR; /* RWORK in complex versions, REAL W buffer for (sd)geev*/ 2037 typ *WI; 2038 typ *VLR; /* REAL VL buffers for _geev where _ is s, d */ 2039 typ *VRR; /* REAL VR buffers for _geev where _ is s, d */ 2040 typ *WORK; 2041 typ *W; /* final w */ 2042 typ *VL; /* final vl */ 2043 typ *VR; /* final vr */ 2044 2045 fortran_int N; 2046 fortran_int LDA; 2047 fortran_int LDVL; 2048 fortran_int LDVR; 2049 fortran_int LWORK; 2050 2051 char JOBVL; 2052 char JOBVR; 2053}; 2054 2055template<typename typ> 2056static inline void 2057dump_geev_params(const char *name, GEEV_PARAMS_t<typ>* params) 2058{ 2059 TRACE_TXT("\n%s\n" 2060 2061 "\t%10s: %p\n"\ 2062 "\t%10s: %p\n"\ 2063 "\t%10s: %p\n"\ 2064 "\t%10s: %p\n"\ 2065 "\t%10s: %p\n"\ 2066 "\t%10s: %p\n"\ 2067 "\t%10s: %p\n"\ 2068 "\t%10s: %p\n"\ 2069 "\t%10s: %p\n"\ 2070 2071 "\t%10s: %d\n"\ 2072 "\t%10s: %d\n"\ 2073 "\t%10s: %d\n"\ 2074 "\t%10s: %d\n"\ 2075 "\t%10s: %d\n"\ 2076 2077 "\t%10s: %c\n"\ 2078 "\t%10s: %c\n", 2079 2080 name, 2081 2082 "A", params->A, 2083 "WR", params->WR, 2084 "WI", params->WI, 2085 "VLR", params->VLR, 2086 "VRR", params->VRR, 2087 "WORK", params->WORK, 2088 "W", params->W, 2089 "VL", params->VL, 2090 "VR", params->VR, 2091 2092 "N", (int)params->N, 2093 "LDA", (int)params->LDA, 2094 "LDVL", (int)params->LDVL, 2095 "LDVR", (int)params->LDVR, 2096 "LWORK", (int)params->LWORK, 2097 2098 "JOBVL", params->JOBVL, 2099 "JOBVR", params->JOBVR); 2100} 2101 2102static inline fortran_int 2103call_geev(GEEV_PARAMS_t<float>* params) 2104{ 2105 fortran_int rv; 2106 LAPACK(sgeev)(&params->JOBVL, &params->JOBVR, 2107 &params->N, params->A, &params->LDA, 2108 params->WR, params->WI, 2109 params->VLR, &params->LDVL, 2110 params->VRR, &params->LDVR, 2111 params->WORK, &params->LWORK, 2112 &rv); 2113 return rv; 2114} 2115 2116static inline fortran_int 2117call_geev(GEEV_PARAMS_t<double>* params) 2118{ 2119 fortran_int rv; 2120 LAPACK(dgeev)(&params->JOBVL, &params->JOBVR, 2121 &params->N, params->A, &params->LDA, 2122 params->WR, params->WI, 2123 params->VLR, &params->LDVL, 2124 params->VRR, &params->LDVR, 2125 params->WORK, &params->LWORK, 2126 &rv); 2127 return rv; 2128} 2129 2130 2131template<typename typ> 2132static inline int 2133init_geev(GEEV_PARAMS_t<typ> *params, char jobvl, char jobvr, fortran_int n, 2134scalar_trait) 2135{ 2136 npy_uint8 *mem_buff = NULL; 2137 npy_uint8 *mem_buff2 = NULL; 2138 npy_uint8 *a, *wr, *wi, *vlr, *vrr, *work, *w, *vl, *vr; 2139 size_t safe_n = n; 2140 size_t a_size = safe_n * safe_n * sizeof(typ); 2141 size_t wr_size = safe_n * sizeof(typ); 2142 size_t wi_size = safe_n * sizeof(typ); 2143 size_t vlr_size = jobvl=='V' ? safe_n * safe_n * sizeof(typ) : 0; 2144 size_t vrr_size = jobvr=='V' ? safe_n * safe_n * sizeof(typ) : 0; 2145 size_t w_size = wr_size*2; 2146 size_t vl_size = vlr_size*2; 2147 size_t vr_size = vrr_size*2; 2148 size_t work_count = 0; 2149 fortran_int ld = fortran_int_max(n, 1); 2150 2151 /* allocate data for known sizes (all but work) */ 2152 mem_buff = (npy_uint8 *)malloc(a_size + wr_size + wi_size + 2153 vlr_size + vrr_size + 2154 w_size + vl_size + vr_size); 2155 if (!mem_buff) { 2156 goto error; 2157 } 2158 2159 a = mem_buff; 2160 wr = a + a_size; 2161 wi = wr + wr_size; 2162 vlr = wi + wi_size; 2163 vrr = vlr + vlr_size; 2164 w = vrr + vrr_size; 2165 vl = w + w_size; 2166 vr = vl + vl_size; 2167 2168 params->A = (typ*)a; 2169 params->WR = (typ*)wr; 2170 params->WI = (typ*)wi; 2171 params->VLR = (typ*)vlr; 2172 params->VRR = (typ*)vrr; 2173 params->W = (typ*)w; 2174 params->VL = (typ*)vl; 2175 params->VR = (typ*)vr; 2176 params->N = n; 2177 params->LDA = ld; 2178 params->LDVL = ld; 2179 params->LDVR = ld; 2180 params->JOBVL = jobvl; 2181 params->JOBVR = jobvr; 2182 2183 /* Work size query */ 2184 { 2185 typ work_size_query; 2186 2187 params->LWORK = -1; 2188 params->WORK = &work_size_query; 2189 2190 if (call_geev(params) != 0) { 2191 goto error; 2192 } 2193 2194 work_count = (size_t)work_size_query; 2195 } 2196 2197 mem_buff2 = (npy_uint8 *)malloc(work_count*sizeof(typ)); 2198 if (!mem_buff2) { 2199 goto error; 2200 } 2201 work = mem_buff2; 2202 2203 params->LWORK = (fortran_int)work_count; 2204 params->WORK = (typ*)work; 2205 2206 return 1; 2207 error: 2208 free(mem_buff2); 2209 free(mem_buff); 2210 memset(params, 0, sizeof(*params)); 2211 2212 return 0; 2213} 2214 2215template<typename complextyp, typename typ> 2216static inline void 2217mk_complex_array_from_real(complextyp *c, const typ *re, size_t n) 2218{ 2219 size_t iter; 2220 for (iter = 0; iter < n; ++iter) { 2221 c[iter].r = re[iter]; 2222 c[iter].i = numeric_limits<typ>::zero; 2223 } 2224} 2225 2226template<typename complextyp, typename typ> 2227static inline void 2228mk_complex_array(complextyp *c, 2229 const typ *re, 2230 const typ *im, 2231 size_t n) 2232{ 2233 size_t iter; 2234 for (iter = 0; iter < n; ++iter) { 2235 c[iter].r = re[iter]; 2236 c[iter].i = im[iter]; 2237 } 2238} 2239 2240template<typename complextyp, typename typ> 2241static inline void 2242mk_complex_array_conjugate_pair(complextyp *c, 2243 const typ *r, 2244 size_t n) 2245{ 2246 size_t iter; 2247 for (iter = 0; iter < n; ++iter) { 2248 typ re = r[iter]; 2249 typ im = r[iter+n]; 2250 c[iter].r = re; 2251 c[iter].i = im; 2252 c[iter+n].r = re; 2253 c[iter+n].i = -im; 2254 } 2255} 2256 2257/* 2258 * make the complex eigenvectors from the real array produced by sgeev/zgeev. 2259 * c is the array where the results will be left. 2260 * r is the source array of reals produced by sgeev/zgeev 2261 * i is the eigenvalue imaginary part produced by sgeev/zgeev 2262 * n is so that the order of the matrix is n by n 2263 */ 2264template<typename complextyp, typename typ> 2265static inline void 2266mk_geev_complex_eigenvectors(complextyp *c, 2267 const typ *r, 2268 const typ *i, 2269 size_t n) 2270{ 2271 size_t iter = 0; 2272 while (iter < n) 2273 { 2274 if (i[iter] == numeric_limits<typ>::zero) { 2275 /* eigenvalue was real, eigenvectors as well... */ 2276 mk_complex_array_from_real(c, r, n); 2277 c += n; 2278 r += n; 2279 iter ++; 2280 } else { 2281 /* eigenvalue was complex, generate a pair of eigenvectors */ 2282 mk_complex_array_conjugate_pair(c, r, n); 2283 c += 2*n; 2284 r += 2*n; 2285 iter += 2; 2286 } 2287 } 2288} 2289 2290 2291template<typename complextyp, typename typ> 2292static inline void 2293process_geev_results(GEEV_PARAMS_t<typ> *params, scalar_trait) 2294{ 2295 /* REAL versions of geev need the results to be translated 2296 * into complex versions. This is the way to deal with imaginary 2297 * results. In our gufuncs we will always return complex arrays! 2298 */ 2299 mk_complex_array((complextyp*)params->W, (typ*)params->WR, (typ*)params->WI, params->N); 2300 2301 /* handle the eigenvectors */ 2302 if ('V' == params->JOBVL) { 2303 mk_geev_complex_eigenvectors((complextyp*)params->VL, (typ*)params->VLR, 2304 (typ*)params->WI, params->N); 2305 } 2306 if ('V' == params->JOBVR) { 2307 mk_geev_complex_eigenvectors((complextyp*)params->VR, (typ*)params->VRR, 2308 (typ*)params->WI, params->N); 2309 } 2310} 2311 2312 2313static inline fortran_int 2314call_geev(GEEV_PARAMS_t<fortran_complex>* params) 2315{ 2316 fortran_int rv; 2317 2318 LAPACK(cgeev)(&params->JOBVL, &params->JOBVR, 2319 &params->N, params->A, &params->LDA, 2320 params->W, 2321 params->VL, &params->LDVL, 2322 params->VR, &params->LDVR, 2323 params->WORK, &params->LWORK, 2324 params->WR, /* actually RWORK */ 2325 &rv); 2326 return rv; 2327} 2328static inline fortran_int 2329call_geev(GEEV_PARAMS_t<fortran_doublecomplex>* params) 2330{ 2331 fortran_int rv; 2332 2333 LAPACK(zgeev)(&params->JOBVL, &params->JOBVR, 2334 &params->N, params->A, &params->LDA, 2335 params->W, 2336 params->VL, &params->LDVL, 2337 params->VR, &params->LDVR, 2338 params->WORK, &params->LWORK, 2339 params->WR, /* actually RWORK */ 2340 &rv); 2341 return rv; 2342} 2343 2344template<typename ftyp> 2345static inline int 2346init_geev(GEEV_PARAMS_t<ftyp>* params, 2347 char jobvl, 2348 char jobvr, 2349 fortran_int n, complex_trait) 2350{ 2351using realtyp = basetype_t<ftyp>; 2352 npy_uint8 *mem_buff = NULL; 2353 npy_uint8 *mem_buff2 = NULL; 2354 npy_uint8 *a, *w, *vl, *vr, *work, *rwork; 2355 size_t safe_n = n; 2356 size_t a_size = safe_n * safe_n * sizeof(ftyp); 2357 size_t w_size = safe_n * sizeof(ftyp); 2358 size_t vl_size = jobvl=='V'? safe_n * safe_n * sizeof(ftyp) : 0; 2359 size_t vr_size = jobvr=='V'? safe_n * safe_n * sizeof(ftyp) : 0; 2360 size_t rwork_size = 2 * safe_n * sizeof(realtyp); 2361 size_t work_count = 0; 2362 size_t total_size = a_size + w_size + vl_size + vr_size + rwork_size; 2363 fortran_int ld = fortran_int_max(n, 1); 2364 2365 mem_buff = (npy_uint8 *)malloc(total_size); 2366 if (!mem_buff) { 2367 goto error; 2368 } 2369 2370 a = mem_buff; 2371 w = a + a_size; 2372 vl = w + w_size; 2373 vr = vl + vl_size; 2374 rwork = vr + vr_size; 2375 2376 params->A = (ftyp*)a; 2377 params->WR = (realtyp*)rwork; 2378 params->WI = NULL; 2379 params->VLR = NULL; 2380 params->VRR = NULL; 2381 params->VL = (ftyp*)vl; 2382 params->VR = (ftyp*)vr; 2383 params->W = (ftyp*)w; 2384 params->N = n; 2385 params->LDA = ld; 2386 params->LDVL = ld; 2387 params->LDVR = ld; 2388 params->JOBVL = jobvl; 2389 params->JOBVR = jobvr; 2390 2391 /* Work size query */ 2392 { 2393 ftyp work_size_query; 2394 2395 params->LWORK = -1; 2396 params->WORK = &work_size_query; 2397 2398 if (call_geev(params) != 0) { 2399 goto error; 2400 } 2401 2402 work_count = (size_t) work_size_query.r; 2403 /* Fix a bug in lapack 3.0.0 */ 2404 if(work_count == 0) work_count = 1; 2405 } 2406 2407 mem_buff2 = (npy_uint8 *)malloc(work_count*sizeof(ftyp)); 2408 if (!mem_buff2) { 2409 goto error; 2410 } 2411 2412 work = mem_buff2; 2413 2414 params->LWORK = (fortran_int)work_count; 2415 params->WORK = (ftyp*)work; 2416 2417 return 1; 2418 error: 2419 free(mem_buff2); 2420 free(mem_buff); 2421 memset(params, 0, sizeof(*params)); 2422 2423 return 0; 2424} 2425 2426template<typename complextyp, typename typ> 2427static inline void 2428process_geev_results(GEEV_PARAMS_t<typ> *NPY_UNUSED(params), complex_trait) 2429{ 2430 /* nothing to do here, complex versions are ready to copy out */ 2431} 2432 2433 2434 2435template<typename typ> 2436static inline void 2437release_geev(GEEV_PARAMS_t<typ> *params) 2438{ 2439 free(params->WORK); 2440 free(params->A); 2441 memset(params, 0, sizeof(*params)); 2442} 2443 2444template<typename fctype, typename ftype> 2445static inline void 2446eig_wrapper(char JOBVL, 2447 char JOBVR, 2448 char**args, 2449 npy_intp const *dimensions, 2450 npy_intp const *steps) 2451{ 2452 ptrdiff_t outer_steps[4]; 2453 size_t iter; 2454 size_t outer_dim = *dimensions++; 2455 size_t op_count = 2; 2456 int error_occurred = get_fp_invalid_and_clear(); 2457 GEEV_PARAMS_t<ftype> geev_params; 2458 2459 assert(JOBVL == 'N'); 2460 2461 STACK_TRACE; 2462 op_count += 'V'==JOBVL?1:0; 2463 op_count += 'V'==JOBVR?1:0; 2464 2465 for (iter = 0; iter < op_count; ++iter) { 2466 outer_steps[iter] = (ptrdiff_t) steps[iter]; 2467 } 2468 steps += op_count; 2469 2470 if (init_geev(&geev_params, 2471 JOBVL, JOBVR, 2472 (fortran_int)dimensions[0], dispatch_scalar<ftype>())) { 2473 LINEARIZE_DATA_t a_in; 2474 LINEARIZE_DATA_t w_out; 2475 LINEARIZE_DATA_t vl_out; 2476 LINEARIZE_DATA_t vr_out; 2477 2478 init_linearize_data(&a_in, 2479 geev_params.N, geev_params.N, 2480 steps[1], steps[0]); 2481 steps += 2; 2482 init_linearize_data(&w_out, 2483 1, geev_params.N, 2484 0, steps[0]); 2485 steps += 1; 2486 if ('V' == geev_params.JOBVL) { 2487 init_linearize_data(&vl_out, 2488 geev_params.N, geev_params.N, 2489 steps[1], steps[0]); 2490 steps += 2; 2491 } 2492 if ('V' == geev_params.JOBVR) { 2493 init_linearize_data(&vr_out, 2494 geev_params.N, geev_params.N, 2495 steps[1], steps[0]); 2496 } 2497 2498 for (iter = 0; iter < outer_dim; ++iter) { 2499 int not_ok; 2500 char **arg_iter = args; 2501 /* copy the matrix in */ 2502 linearize_matrix((ftype*)geev_params.A, (ftype*)*arg_iter++, &a_in); 2503 not_ok = call_geev(&geev_params); 2504 2505 if (!not_ok) { 2506 process_geev_results<fctype>(&geev_params, 2507dispatch_scalar<ftype>{}); 2508 delinearize_matrix((fctype*)*arg_iter++, 2509 (fctype*)geev_params.W, 2510 &w_out); 2511 2512 if ('V' == geev_params.JOBVL) { 2513 delinearize_matrix((fctype*)*arg_iter++, 2514 (fctype*)geev_params.VL, 2515 &vl_out); 2516 } 2517 if ('V' == geev_params.JOBVR) { 2518 delinearize_matrix((fctype*)*arg_iter++, 2519 (fctype*)geev_params.VR, 2520 &vr_out); 2521 } 2522 } else { 2523 /* geev failed */ 2524 error_occurred = 1; 2525 nan_matrix((fctype*)*arg_iter++, &w_out); 2526 if ('V' == geev_params.JOBVL) { 2527 nan_matrix((fctype*)*arg_iter++, &vl_out); 2528 } 2529 if ('V' == geev_params.JOBVR) { 2530 nan_matrix((fctype*)*arg_iter++, &vr_out); 2531 } 2532 } 2533 update_pointers((npy_uint8**)args, outer_steps, op_count); 2534 } 2535 2536 release_geev(&geev_params); 2537 } 2538 2539 set_fp_invalid_or_clear(error_occurred); 2540} 2541 2542template<typename fctype, typename ftype> 2543static void 2544eig(char **args, 2545 npy_intp const *dimensions, 2546 npy_intp const *steps, 2547 void *NPY_UNUSED(func)) 2548{ 2549 eig_wrapper<fctype, ftype>('N', 'V', args, dimensions, steps); 2550} 2551 2552template<typename fctype, typename ftype> 2553static void 2554eigvals(char **args, 2555 npy_intp const *dimensions, 2556 npy_intp const *steps, 2557 void *NPY_UNUSED(func)) 2558{ 2559 eig_wrapper<fctype, ftype>('N', 'N', args, dimensions, steps); 2560} 2561 2562 2563 2564/* -------------------------------------------------------------------------- */ 2565 /* singular value decomposition */ 2566 2567template<typename ftyp> 2568struct GESDD_PARAMS_t 2569{ 2570 ftyp *A; 2571 basetype_t<ftyp> *S; 2572 ftyp *U; 2573 ftyp *VT; 2574 ftyp *WORK; 2575 basetype_t<ftyp> *RWORK; 2576 fortran_int *IWORK; 2577 2578 fortran_int M; 2579 fortran_int N; 2580 fortran_int LDA; 2581 fortran_int LDU; 2582 fortran_int LDVT; 2583 fortran_int LWORK; 2584 char JOBZ; 2585} ; 2586 2587 2588template<typename ftyp> 2589static inline void 2590dump_gesdd_params(const char *name, 2591 GESDD_PARAMS_t<ftyp> *params) 2592{ 2593 TRACE_TXT("\n%s:\n"\ 2594 2595 "%14s: %18p\n"\ 2596 "%14s: %18p\n"\ 2597 "%14s: %18p\n"\ 2598 "%14s: %18p\n"\ 2599 "%14s: %18p\n"\ 2600 "%14s: %18p\n"\ 2601 "%14s: %18p\n"\ 2602 2603 "%14s: %18d\n"\ 2604 "%14s: %18d\n"\ 2605 "%14s: %18d\n"\ 2606 "%14s: %18d\n"\ 2607 "%14s: %18d\n"\ 2608 "%14s: %18d\n"\ 2609 2610 "%14s: %15c'%c'\n", 2611 2612 name, 2613 2614 "A", params->A, 2615 "S", params->S, 2616 "U", params->U, 2617 "VT", params->VT, 2618 "WORK", params->WORK, 2619 "RWORK", params->RWORK, 2620 "IWORK", params->IWORK, 2621 2622 "M", (int)params->M, 2623 "N", (int)params->N, 2624 "LDA", (int)params->LDA, 2625 "LDU", (int)params->LDU, 2626 "LDVT", (int)params->LDVT, 2627 "LWORK", (int)params->LWORK, 2628 2629 "JOBZ", ' ', params->JOBZ); 2630} 2631 2632static inline int 2633compute_urows_vtcolumns(char jobz, 2634 fortran_int m, fortran_int n, 2635 fortran_int *urows, fortran_int *vtcolumns) 2636{ 2637 fortran_int min_m_n = fortran_int_min(m, n); 2638 switch(jobz) 2639 { 2640 case 'N': 2641 *urows = 0; 2642 *vtcolumns = 0; 2643 break; 2644 case 'A': 2645 *urows = m; 2646 *vtcolumns = n; 2647 break; 2648 case 'S': 2649 { 2650 *urows = min_m_n; 2651 *vtcolumns = min_m_n; 2652 } 2653 break; 2654 default: 2655 return 0; 2656 } 2657 2658 return 1; 2659} 2660 2661static inline fortran_int 2662call_gesdd(GESDD_PARAMS_t<fortran_real> *params) 2663{ 2664 fortran_int rv; 2665 LAPACK(sgesdd)(&params->JOBZ, &params->M, &params->N, 2666 params->A, &params->LDA, 2667 params->S, 2668 params->U, &params->LDU, 2669 params->VT, &params->LDVT, 2670 params->WORK, &params->LWORK, 2671 (fortran_int*)params->IWORK, 2672 &rv); 2673 return rv; 2674} 2675static inline fortran_int 2676call_gesdd(GESDD_PARAMS_t<fortran_doublereal> *params) 2677{ 2678 fortran_int rv; 2679 LAPACK(dgesdd)(&params->JOBZ, &params->M, &params->N, 2680 params->A, &params->LDA, 2681 params->S, 2682 params->U, &params->LDU, 2683 params->VT, &params->LDVT, 2684 params->WORK, &params->LWORK, 2685 (fortran_int*)params->IWORK, 2686 &rv); 2687 return rv; 2688} 2689 2690template<typename ftyp> 2691static inline int 2692init_gesdd(GESDD_PARAMS_t<ftyp> *params, 2693 char jobz, 2694 fortran_int m, 2695 fortran_int n, scalar_trait) 2696{ 2697 npy_uint8 *mem_buff = NULL; 2698 npy_uint8 *mem_buff2 = NULL; 2699 npy_uint8 *a, *s, *u, *vt, *work, *iwork; 2700 size_t safe_m = m; 2701 size_t safe_n = n; 2702 size_t a_size = safe_m * safe_n * sizeof(ftyp); 2703 fortran_int min_m_n = fortran_int_min(m, n); 2704 size_t safe_min_m_n = min_m_n; 2705 size_t s_size = safe_min_m_n * sizeof(ftyp); 2706 fortran_int u_row_count, vt_column_count; 2707 size_t safe_u_row_count, safe_vt_column_count; 2708 size_t u_size, vt_size; 2709 fortran_int work_count; 2710 size_t work_size; 2711 size_t iwork_size = 8 * safe_min_m_n * sizeof(fortran_int); 2712 fortran_int ld = fortran_int_max(m, 1); 2713 2714 if (!compute_urows_vtcolumns(jobz, m, n, &u_row_count, &vt_column_count)) { 2715 goto error; 2716 } 2717 2718 safe_u_row_count = u_row_count; 2719 safe_vt_column_count = vt_column_count; 2720 2721 u_size = safe_u_row_count * safe_m * sizeof(ftyp); 2722 vt_size = safe_n * safe_vt_column_count * sizeof(ftyp); 2723 2724 mem_buff = (npy_uint8 *)malloc(a_size + s_size + u_size + vt_size + iwork_size); 2725 2726 if (!mem_buff) { 2727 goto error; 2728 } 2729 2730 a = mem_buff; 2731 s = a + a_size; 2732 u = s + s_size; 2733 vt = u + u_size; 2734 iwork = vt + vt_size; 2735 2736 /* fix vt_column_count so that it is a valid lapack parameter (0 is not) */ 2737 vt_column_count = fortran_int_max(1, vt_column_count); 2738 2739 params->M = m; 2740 params->N = n; 2741 params->A = (ftyp*)a; 2742 params->S = (ftyp*)s; 2743 params->U = (ftyp*)u; 2744 params->VT = (ftyp*)vt; 2745 params->RWORK = NULL; 2746 params->IWORK = (fortran_int*)iwork; 2747 params->LDA = ld; 2748 params->LDU = ld; 2749 params->LDVT = vt_column_count; 2750 params->JOBZ = jobz; 2751 2752 /* Work size query */ 2753 { 2754 ftyp work_size_query; 2755 2756 params->LWORK = -1; 2757 params->WORK = &work_size_query; 2758 2759 if (call_gesdd(params) != 0) { 2760 goto error; 2761 } 2762 2763 work_count = (fortran_int)work_size_query; 2764 /* Fix a bug in lapack 3.0.0 */ 2765 if(work_count == 0) work_count = 1; 2766 work_size = (size_t)work_count * sizeof(ftyp); 2767 } 2768 2769 mem_buff2 = (npy_uint8 *)malloc(work_size); 2770 if (!mem_buff2) { 2771 goto error; 2772 } 2773 2774 work = mem_buff2; 2775 2776 params->LWORK = work_count; 2777 params->WORK = (ftyp*)work; 2778 2779 return 1; 2780 error: 2781 TRACE_TXT("%s failed init\n", __FUNCTION__); 2782 free(mem_buff); 2783 free(mem_buff2); 2784 memset(params, 0, sizeof(*params)); 2785 2786 return 0; 2787} 2788 2789static inline fortran_int 2790call_gesdd(GESDD_PARAMS_t<fortran_complex> *params) 2791{ 2792 fortran_int rv; 2793 LAPACK(cgesdd)(&params->JOBZ, &params->M, &params->N, 2794 params->A, &params->LDA, 2795 params->S, 2796 params->U, &params->LDU, 2797 params->VT, &params->LDVT, 2798 params->WORK, &params->LWORK, 2799 params->RWORK, 2800 params->IWORK, 2801 &rv); 2802 return rv; 2803} 2804static inline fortran_int 2805call_gesdd(GESDD_PARAMS_t<fortran_doublecomplex> *params) 2806{ 2807 fortran_int rv; 2808 LAPACK(zgesdd)(&params->JOBZ, &params->M, &params->N, 2809 params->A, &params->LDA, 2810 params->S, 2811 params->U, &params->LDU, 2812 params->VT, &params->LDVT, 2813 params->WORK, &params->LWORK, 2814 params->RWORK, 2815 params->IWORK, 2816 &rv); 2817 return rv; 2818} 2819 2820template<typename ftyp> 2821static inline int 2822init_gesdd(GESDD_PARAMS_t<ftyp> *params, 2823 char jobz, 2824 fortran_int m, 2825 fortran_int n, complex_trait) 2826{ 2827using frealtyp = basetype_t<ftyp>; 2828 npy_uint8 *mem_buff = NULL, *mem_buff2 = NULL; 2829 npy_uint8 *a,*s, *u, *vt, *work, *rwork, *iwork; 2830 size_t a_size, s_size, u_size, vt_size, work_size, rwork_size, iwork_size; 2831 size_t safe_u_row_count, safe_vt_column_count; 2832 fortran_int u_row_count, vt_column_count, work_count; 2833 size_t safe_m = m; 2834 size_t safe_n = n; 2835 fortran_int min_m_n = fortran_int_min(m, n); 2836 size_t safe_min_m_n = min_m_n; 2837 fortran_int ld = fortran_int_max(m, 1); 2838 2839 if (!compute_urows_vtcolumns(jobz, m, n, &u_row_count, &vt_column_count)) { 2840 goto error; 2841 } 2842 2843 safe_u_row_count = u_row_count; 2844 safe_vt_column_count = vt_column_count; 2845 2846 a_size = safe_m * safe_n * sizeof(ftyp); 2847 s_size = safe_min_m_n * sizeof(frealtyp); 2848 u_size = safe_u_row_count * safe_m * sizeof(ftyp); 2849 vt_size = safe_n * safe_vt_column_count * sizeof(ftyp); 2850 rwork_size = 'N'==jobz? 2851 (7 * safe_min_m_n) : 2852 (5*safe_min_m_n * safe_min_m_n + 5*safe_min_m_n); 2853 rwork_size *= sizeof(ftyp); 2854 iwork_size = 8 * safe_min_m_n* sizeof(fortran_int); 2855 2856 mem_buff = (npy_uint8 *)malloc(a_size + 2857 s_size + 2858 u_size + 2859 vt_size + 2860 rwork_size + 2861 iwork_size); 2862 if (!mem_buff) { 2863 goto error; 2864 } 2865 2866 a = mem_buff; 2867 s = a + a_size; 2868 u = s + s_size; 2869 vt = u + u_size; 2870 rwork = vt + vt_size; 2871 iwork = rwork + rwork_size; 2872 2873 /* fix vt_column_count so that it is a valid lapack parameter (0 is not) */ 2874 vt_column_count = fortran_int_max(1, vt_column_count); 2875 2876 params->A = (ftyp*)a; 2877 params->S = (frealtyp*)s; 2878 params->U = (ftyp*)u; 2879 params->VT = (ftyp*)vt; 2880 params->RWORK = (frealtyp*)rwork; 2881 params->IWORK = (fortran_int*)iwork; 2882 params->M = m; 2883 params->N = n; 2884 params->LDA = ld; 2885 params->LDU = ld; 2886 params->LDVT = vt_column_count; 2887 params->JOBZ = jobz; 2888 2889 /* Work size query */ 2890 { 2891 ftyp work_size_query; 2892 2893 params->LWORK = -1; 2894 params->WORK = &work_size_query; 2895 2896 if (call_gesdd(params) != 0) { 2897 goto error; 2898 } 2899 2900 work_count = (fortran_int)(*(frealtyp*)&work_size_query); 2901 /* Fix a bug in lapack 3.0.0 */ 2902 if(work_count == 0) work_count = 1; 2903 work_size = (size_t)work_count * sizeof(ftyp); 2904 } 2905 2906 mem_buff2 = (npy_uint8 *)malloc(work_size); 2907 if (!mem_buff2) { 2908 goto error; 2909 } 2910 2911 work = mem_buff2; 2912 2913 params->LWORK = work_count; 2914 params->WORK = (ftyp*)work; 2915 2916 return 1; 2917 error: 2918 TRACE_TXT("%s failed init\n", __FUNCTION__); 2919 free(mem_buff2); 2920 free(mem_buff); 2921 memset(params, 0, sizeof(*params)); 2922 2923 return 0; 2924} 2925 2926template<typename typ> 2927static inline void 2928release_gesdd(GESDD_PARAMS_t<typ>* params) 2929{ 2930 /* A and WORK contain allocated blocks */ 2931 free(params->A); 2932 free(params->WORK); 2933 memset(params, 0, sizeof(*params)); 2934} 2935 2936template<typename typ> 2937static inline void 2938svd_wrapper(char JOBZ, 2939 char **args, 2940 npy_intp const *dimensions, 2941 npy_intp const *steps) 2942{ 2943using basetyp = basetype_t<typ>; 2944 ptrdiff_t outer_steps[4]; 2945 int error_occurred = get_fp_invalid_and_clear(); 2946 size_t iter; 2947 size_t outer_dim = *dimensions++; 2948 size_t op_count = (JOBZ=='N')?2:4; 2949 GESDD_PARAMS_t<typ> params; 2950 2951 for (iter = 0; iter < op_count; ++iter) { 2952 outer_steps[iter] = (ptrdiff_t) steps[iter]; 2953 } 2954 steps += op_count; 2955 2956 if (init_gesdd(&params, 2957 JOBZ, 2958 (fortran_int)dimensions[0], 2959 (fortran_int)dimensions[1], 2960dispatch_scalar<typ>())) { 2961 LINEARIZE_DATA_t a_in, u_out, s_out, v_out; 2962 fortran_int min_m_n = params.M < params.N ? params.M : params.N; 2963 2964 init_linearize_data(&a_in, params.N, params.M, steps[1], steps[0]); 2965 if ('N' == params.JOBZ) { 2966 /* only the singular values are wanted */ 2967 init_linearize_data(&s_out, 1, min_m_n, 0, steps[2]); 2968 } else { 2969 fortran_int u_columns, v_rows; 2970 if ('S' == params.JOBZ) { 2971 u_columns = min_m_n; 2972 v_rows = min_m_n; 2973 } else { /* JOBZ == 'A' */ 2974 u_columns = params.M; 2975 v_rows = params.N; 2976 } 2977 init_linearize_data(&u_out, 2978 u_columns, params.M, 2979 steps[3], steps[2]); 2980 init_linearize_data(&s_out, 2981 1, min_m_n, 2982 0, steps[4]); 2983 init_linearize_data(&v_out, 2984 params.N, v_rows, 2985 steps[6], steps[5]); 2986 } 2987 2988 for (iter = 0; iter < outer_dim; ++iter) { 2989 int not_ok; 2990 /* copy the matrix in */ 2991 linearize_matrix((typ*)params.A, (typ*)args[0], &a_in); 2992 not_ok = call_gesdd(&params); 2993 if (!not_ok) { 2994 if ('N' == params.JOBZ) { 2995 delinearize_matrix((basetyp*)args[1], (basetyp*)params.S, &s_out); 2996 } else { 2997 if ('A' == params.JOBZ && min_m_n == 0) { 2998 /* Lapack has betrayed us and left these uninitialized, 2999 * so produce an identity matrix for whichever of u 3000 * and v is not empty. 3001 */ 3002 identity_matrix((typ*)params.U, params.M); 3003 identity_matrix((typ*)params.VT, params.N); 3004 } 3005 3006 delinearize_matrix((typ*)args[1], (typ*)params.U, &u_out); 3007 delinearize_matrix((basetyp*)args[2], (basetyp*)params.S, &s_out); 3008 delinearize_matrix((typ*)args[3], (typ*)params.VT, &v_out); 3009 } 3010 } else { 3011 error_occurred = 1; 3012 if ('N' == params.JOBZ) { 3013 nan_matrix((basetyp*)args[1], &s_out); 3014 } else { 3015 nan_matrix((typ*)args[1], &u_out); 3016 nan_matrix((basetyp*)args[2], &s_out); 3017 nan_matrix((typ*)args[3], &v_out); 3018 } 3019 } 3020 update_pointers((npy_uint8**)args, outer_steps, op_count); 3021 } 3022 3023 release_gesdd(&params); 3024 } 3025 3026 set_fp_invalid_or_clear(error_occurred); 3027} 3028 3029 3030template<typename typ> 3031static void 3032svd_N(char **args, 3033 npy_intp const *dimensions, 3034 npy_intp const *steps, 3035 void *NPY_UNUSED(func)) 3036{ 3037 svd_wrapper<fortran_type_t<typ>>('N', args, dimensions, steps); 3038} 3039 3040template<typename typ> 3041static void 3042svd_S(char **args, 3043 npy_intp const *dimensions, 3044 npy_intp const *steps, 3045 void *NPY_UNUSED(func)) 3046{ 3047 svd_wrapper<fortran_type_t<typ>>('S', args, dimensions, steps); 3048} 3049 3050template<typename typ> 3051static void 3052svd_A(char **args, 3053 npy_intp const *dimensions, 3054 npy_intp const *steps, 3055 void *NPY_UNUSED(func)) 3056{ 3057 svd_wrapper<fortran_type_t<typ>>('A', args, dimensions, steps); 3058} 3059 3060/* -------------------------------------------------------------------------- */ 3061 /* qr (modes - r, raw) */ 3062 3063template<typename typ> 3064struct GEQRF_PARAMS_t 3065{ 3066 fortran_int M; 3067 fortran_int N; 3068 typ *A; 3069 fortran_int LDA; 3070 typ* TAU; 3071 typ *WORK; 3072 fortran_int LWORK; 3073}; 3074 3075 3076template<typename typ> 3077static inline void 3078dump_geqrf_params(const char *name, 3079 GEQRF_PARAMS_t<typ> *params) 3080{ 3081 TRACE_TXT("\n%s:\n"\ 3082 3083 "%14s: %18p\n"\ 3084 "%14s: %18p\n"\ 3085 "%14s: %18p\n"\ 3086 "%14s: %18d\n"\ 3087 "%14s: %18d\n"\ 3088 "%14s: %18d\n"\ 3089 "%14s: %18d\n", 3090 3091 name, 3092 3093 "A", params->A, 3094 "TAU", params->TAU, 3095 "WORK", params->WORK, 3096 3097 "M", (int)params->M, 3098 "N", (int)params->N, 3099 "LDA", (int)params->LDA, 3100 "LWORK", (int)params->LWORK); 3101} 3102 3103static inline fortran_int 3104call_geqrf(GEQRF_PARAMS_t<double> *params) 3105{ 3106 fortran_int rv; 3107 LAPACK(dgeqrf)(&params->M, &params->N, 3108 params->A, &params->LDA, 3109 params->TAU, 3110 params->WORK, &params->LWORK, 3111 &rv); 3112 return rv; 3113} 3114static inline fortran_int 3115call_geqrf(GEQRF_PARAMS_t<f2c_doublecomplex> *params) 3116{ 3117 fortran_int rv; 3118 LAPACK(zgeqrf)(&params->M, &params->N, 3119 params->A, &params->LDA, 3120 params->TAU, 3121 params->WORK, &params->LWORK, 3122 &rv); 3123 return rv; 3124} 3125 3126 3127static inline int 3128init_geqrf(GEQRF_PARAMS_t<fortran_doublereal> *params, 3129 fortran_int m, 3130 fortran_int n) 3131{ 3132using ftyp = fortran_doublereal; 3133 npy_uint8 *mem_buff = NULL; 3134 npy_uint8 *mem_buff2 = NULL; 3135 npy_uint8 *a, *tau, *work; 3136 fortran_int min_m_n = fortran_int_min(m, n); 3137 size_t safe_min_m_n = min_m_n; 3138 size_t safe_m = m; 3139 size_t safe_n = n; 3140 3141 size_t a_size = safe_m * safe_n * sizeof(ftyp); 3142 size_t tau_size = safe_min_m_n * sizeof(ftyp); 3143 3144 fortran_int work_count; 3145 size_t work_size; 3146 fortran_int lda = fortran_int_max(1, m); 3147 3148 mem_buff = (npy_uint8 *)malloc(a_size + tau_size); 3149 3150 if (!mem_buff) 3151 goto error; 3152 3153 a = mem_buff; 3154 tau = a + a_size; 3155 memset(tau, 0, tau_size); 3156 3157 3158 params->M = m; 3159 params->N = n; 3160 params->A = (ftyp*)a; 3161 params->TAU = (ftyp*)tau; 3162 params->LDA = lda; 3163 3164 { 3165 /* compute optimal work size */ 3166 3167 ftyp work_size_query; 3168 3169 params->WORK = &work_size_query; 3170 params->LWORK = -1; 3171 3172 if (call_geqrf(params) != 0) 3173 goto error; 3174 3175 work_count = (fortran_int) *(ftyp*) params->WORK; 3176 3177 } 3178 3179 params->LWORK = fortran_int_max(fortran_int_max(1, n), work_count); 3180 3181 work_size = (size_t) params->LWORK * sizeof(ftyp); 3182 mem_buff2 = (npy_uint8 *)malloc(work_size); 3183 if (!mem_buff2) 3184 goto error; 3185 3186 work = mem_buff2; 3187 3188 params->WORK = (ftyp*)work; 3189 3190 return 1; 3191 error: 3192 TRACE_TXT("%s failed init\n", __FUNCTION__); 3193 free(mem_buff); 3194 free(mem_buff2); 3195 memset(params, 0, sizeof(*params)); 3196 3197 return 0; 3198} 3199 3200 3201static inline int 3202init_geqrf(GEQRF_PARAMS_t<fortran_doublecomplex> *params, 3203 fortran_int m, 3204 fortran_int n) 3205{ 3206using ftyp = fortran_doublecomplex; 3207 npy_uint8 *mem_buff = NULL; 3208 npy_uint8 *mem_buff2 = NULL; 3209 npy_uint8 *a, *tau, *work; 3210 fortran_int min_m_n = fortran_int_min(m, n); 3211 size_t safe_min_m_n = min_m_n; 3212 size_t safe_m = m; 3213 size_t safe_n = n; 3214 3215 size_t a_size = safe_m * safe_n * sizeof(ftyp); 3216 size_t tau_size = safe_min_m_n * sizeof(ftyp); 3217 3218 fortran_int work_count; 3219 size_t work_size; 3220 fortran_int lda = fortran_int_max(1, m); 3221 3222 mem_buff = (npy_uint8 *)malloc(a_size + tau_size); 3223 3224 if (!mem_buff) 3225 goto error; 3226 3227 a = mem_buff; 3228 tau = a + a_size; 3229 memset(tau, 0, tau_size); 3230 3231 3232 params->M = m; 3233 params->N = n; 3234 params->A = (ftyp*)a; 3235 params->TAU = (ftyp*)tau; 3236 params->LDA = lda; 3237 3238 { 3239 /* compute optimal work size */ 3240 3241 ftyp work_size_query; 3242 3243 params->WORK = &work_size_query; 3244 params->LWORK = -1; 3245 3246 if (call_geqrf(params) != 0) 3247 goto error; 3248 3249 work_count = (fortran_int) ((ftyp*)params->WORK)->r; 3250 3251 } 3252 3253 params->LWORK = fortran_int_max(fortran_int_max(1, n), 3254 work_count); 3255 3256 work_size = (size_t) params->LWORK * sizeof(ftyp); 3257 3258 mem_buff2 = (npy_uint8 *)malloc(work_size); 3259 if (!mem_buff2) 3260 goto error; 3261 3262 work = mem_buff2; 3263 3264 params->WORK = (ftyp*)work; 3265 3266 return 1; 3267 error: 3268 TRACE_TXT("%s failed init\n", __FUNCTION__); 3269 free(mem_buff); 3270 free(mem_buff2); 3271 memset(params, 0, sizeof(*params)); 3272 3273 return 0; 3274} 3275 3276 3277template<typename ftyp> 3278static inline void 3279release_geqrf(GEQRF_PARAMS_t<ftyp>* params) 3280{ 3281 /* A and WORK contain allocated blocks */ 3282 free(params->A); 3283 free(params->WORK); 3284 memset(params, 0, sizeof(*params)); 3285} 3286 3287template<typename typ> 3288static void 3289qr_r_raw(char **args, npy_intp const *dimensions, npy_intp const *steps, 3290 void *NPY_UNUSED(func)) 3291{ 3292using ftyp = fortran_type_t<typ>; 3293 3294 GEQRF_PARAMS_t<ftyp> params; 3295 int error_occurred = get_fp_invalid_and_clear(); 3296 fortran_int n, m; 3297 3298 INIT_OUTER_LOOP_2 3299 3300 m = (fortran_int)dimensions[0]; 3301 n = (fortran_int)dimensions[1]; 3302 3303 if (init_geqrf(&params, m, n)) { 3304 LINEARIZE_DATA_t a_in, tau_out; 3305 3306 init_linearize_data(&a_in, n, m, steps[1], steps[0]); 3307 init_linearize_data(&tau_out, 1, fortran_int_min(m, n), 1, steps[2]); 3308 3309 BEGIN_OUTER_LOOP_2 3310 int not_ok; 3311 linearize_matrix((typ*)params.A, (typ*)args[0], &a_in); 3312 not_ok = call_geqrf(&params); 3313 if (!not_ok) { 3314 delinearize_matrix((typ*)args[0], (typ*)params.A, &a_in); 3315 delinearize_matrix((typ*)args[1], (typ*)params.TAU, &tau_out); 3316 } else { 3317 error_occurred = 1; 3318 nan_matrix((typ*)args[1], &tau_out); 3319 } 3320 END_OUTER_LOOP 3321 3322 release_geqrf(&params); 3323 } 3324 3325 set_fp_invalid_or_clear(error_occurred); 3326} 3327 3328 3329/* -------------------------------------------------------------------------- */ 3330 /* qr common code (modes - reduced and complete) */ 3331 3332template<typename typ> 3333struct GQR_PARAMS_t 3334{ 3335 fortran_int M; 3336 fortran_int MC; 3337 fortran_int MN; 3338 void* A; 3339 typ *Q; 3340 fortran_int LDA; 3341 typ* TAU; 3342 typ *WORK; 3343 fortran_int LWORK; 3344} ; 3345 3346static inline fortran_int 3347call_gqr(GQR_PARAMS_t<double> *params) 3348{ 3349 fortran_int rv; 3350 LAPACK(dorgqr)(&params->M, &params->MC, &params->MN, 3351 params->Q, &params->LDA, 3352 params->TAU, 3353 params->WORK, &params->LWORK, 3354 &rv); 3355 return rv; 3356} 3357static inline fortran_int 3358call_gqr(GQR_PARAMS_t<f2c_doublecomplex> *params) 3359{ 3360 fortran_int rv; 3361 LAPACK(zungqr)(&params->M, &params->MC, &params->MN, 3362 params->Q, &params->LDA, 3363 params->TAU, 3364 params->WORK, &params->LWORK, 3365 &rv); 3366 return rv; 3367} 3368 3369static inline int 3370init_gqr_common(GQR_PARAMS_t<fortran_doublereal> *params, 3371 fortran_int m, 3372 fortran_int n, 3373 fortran_int mc) 3374{ 3375using ftyp = fortran_doublereal; 3376 npy_uint8 *mem_buff = NULL; 3377 npy_uint8 *mem_buff2 = NULL; 3378 npy_uint8 *a, *q, *tau, *work; 3379 fortran_int min_m_n = fortran_int_min(m, n); 3380 size_t safe_mc = mc; 3381 size_t safe_min_m_n = min_m_n; 3382 size_t safe_m = m; 3383 size_t safe_n = n; 3384 size_t a_size = safe_m * safe_n * sizeof(ftyp); 3385 size_t q_size = safe_m * safe_mc * sizeof(ftyp); 3386 size_t tau_size = safe_min_m_n * sizeof(ftyp); 3387 3388 fortran_int work_count; 3389 size_t work_size; 3390 fortran_int lda = fortran_int_max(1, m); 3391 3392 mem_buff = (npy_uint8 *)malloc(q_size + tau_size + a_size); 3393 3394 if (!mem_buff) 3395 goto error; 3396 3397 q = mem_buff; 3398 tau = q + q_size; 3399 a = tau + tau_size; 3400 3401 3402 params->M = m; 3403 params->MC = mc; 3404 params->MN = min_m_n; 3405 params->A = a; 3406 params->Q = (ftyp*)q; 3407 params->TAU = (ftyp*)tau; 3408 params->LDA = lda; 3409 3410 { 3411 /* compute optimal work size */ 3412 ftyp work_size_query; 3413 3414 params->WORK = &work_size_query; 3415 params->LWORK = -1; 3416 3417 if (call_gqr(params) != 0) 3418 goto error; 3419 3420 work_count = (fortran_int) *(ftyp*) params->WORK; 3421 3422 } 3423 3424 params->LWORK = fortran_int_max(fortran_int_max(1, n), work_count); 3425 3426 work_size = (size_t) params->LWORK * sizeof(ftyp); 3427 3428 mem_buff2 = (npy_uint8 *)malloc(work_size); 3429 if (!mem_buff2) 3430 goto error; 3431 3432 work = mem_buff2; 3433 3434 params->WORK = (ftyp*)work; 3435 3436 return 1; 3437 error: 3438 TRACE_TXT("%s failed init\n", __FUNCTION__); 3439 free(mem_buff); 3440 free(mem_buff2); 3441 memset(params, 0, sizeof(*params)); 3442 3443 return 0; 3444} 3445 3446 3447static inline int 3448init_gqr_common(GQR_PARAMS_t<fortran_doublecomplex> *params, 3449 fortran_int m, 3450 fortran_int n, 3451 fortran_int mc) 3452{ 3453using ftyp=fortran_doublecomplex; 3454 npy_uint8 *mem_buff = NULL; 3455 npy_uint8 *mem_buff2 = NULL; 3456 npy_uint8 *a, *q, *tau, *work; 3457 fortran_int min_m_n = fortran_int_min(m, n); 3458 size_t safe_mc = mc; 3459 size_t safe_min_m_n = min_m_n; 3460 size_t safe_m = m; 3461 size_t safe_n = n; 3462 3463 size_t a_size = safe_m * safe_n * sizeof(ftyp); 3464 size_t q_size = safe_m * safe_mc * sizeof(ftyp); 3465 size_t tau_size = safe_min_m_n * sizeof(ftyp); 3466 3467 fortran_int work_count; 3468 size_t work_size; 3469 fortran_int lda = fortran_int_max(1, m); 3470 3471 mem_buff = (npy_uint8 *)malloc(q_size + tau_size + a_size); 3472 3473 if (!mem_buff) 3474 goto error; 3475 3476 q = mem_buff; 3477 tau = q + q_size; 3478 a = tau + tau_size; 3479 3480 3481 params->M = m; 3482 params->MC = mc; 3483 params->MN = min_m_n; 3484 params->A = a; 3485 params->Q = (ftyp*)q; 3486 params->TAU = (ftyp*)tau; 3487 params->LDA = lda; 3488 3489 { 3490 /* compute optimal work size */ 3491 ftyp work_size_query; 3492 3493 params->WORK = &work_size_query; 3494 params->LWORK = -1; 3495 3496 if (call_gqr(params) != 0) 3497 goto error; 3498 3499 work_count = (fortran_int) ((ftyp*)params->WORK)->r; 3500 3501 } 3502 3503 params->LWORK = fortran_int_max(fortran_int_max(1, n), 3504 work_count); 3505 3506 work_size = (size_t) params->LWORK * sizeof(ftyp); 3507 3508 mem_buff2 = (npy_uint8 *)malloc(work_size); 3509 if (!mem_buff2) 3510 goto error; 3511 3512 work = mem_buff2; 3513 3514 params->WORK = (ftyp*)work; 3515 params->LWORK = work_count; 3516 3517 return 1; 3518 error: 3519 TRACE_TXT("%s failed init\n", __FUNCTION__); 3520 free(mem_buff); 3521 free(mem_buff2); 3522 memset(params, 0, sizeof(*params)); 3523 3524 return 0; 3525} 3526 3527/* -------------------------------------------------------------------------- */ 3528 /* qr (modes - reduced) */ 3529 3530 3531template<typename typ> 3532static inline void 3533dump_gqr_params(const char *name, 3534 GQR_PARAMS_t<typ> *params) 3535{ 3536 TRACE_TXT("\n%s:\n"\ 3537 3538 "%14s: %18p\n"\ 3539 "%14s: %18p\n"\ 3540 "%14s: %18p\n"\ 3541 "%14s: %18d\n"\ 3542 "%14s: %18d\n"\ 3543 "%14s: %18d\n"\ 3544 "%14s: %18d\n"\ 3545 "%14s: %18d\n", 3546 3547 name, 3548 3549 "Q", params->Q, 3550 "TAU", params->TAU, 3551 "WORK", params->WORK, 3552 3553 "M", (int)params->M, 3554 "MC", (int)params->MC, 3555 "MN", (int)params->MN, 3556 "LDA", (int)params->LDA, 3557 "LWORK", (int)params->LWORK); 3558} 3559 3560template<typename ftyp> 3561static inline int 3562init_gqr(GQR_PARAMS_t<ftyp> *params, 3563 fortran_int m, 3564 fortran_int n) 3565{ 3566 return init_gqr_common( 3567 params, m, n, 3568 fortran_int_min(m, n)); 3569} 3570 3571 3572template<typename typ> 3573static inline void 3574release_gqr(GQR_PARAMS_t<typ>* params) 3575{ 3576 /* A and WORK contain allocated blocks */ 3577 free(params->Q); 3578 free(params->WORK); 3579 memset(params, 0, sizeof(*params)); 3580} 3581 3582template<typename typ> 3583static void 3584qr_reduced(char **args, npy_intp const *dimensions, npy_intp const *steps, 3585 void *NPY_UNUSED(func)) 3586{ 3587using ftyp = fortran_type_t<typ>; 3588 GQR_PARAMS_t<ftyp> params; 3589 int error_occurred = get_fp_invalid_and_clear(); 3590 fortran_int n, m; 3591 3592 INIT_OUTER_LOOP_3 3593 3594 m = (fortran_int)dimensions[0]; 3595 n = (fortran_int)dimensions[1]; 3596 3597 if (init_gqr(&params, m, n)) { 3598 LINEARIZE_DATA_t a_in, tau_in, q_out; 3599 3600 init_linearize_data(&a_in, n, m, steps[1], steps[0]); 3601 init_linearize_data(&tau_in, 1, fortran_int_min(m, n), 1, steps[2]); 3602 init_linearize_data(&q_out, fortran_int_min(m, n), m, steps[4], steps[3]); 3603 3604 BEGIN_OUTER_LOOP_3 3605 int not_ok; 3606 linearize_matrix((typ*)params.A, (typ*)args[0], &a_in); 3607 linearize_matrix((typ*)params.Q, (typ*)args[0], &a_in); 3608 linearize_matrix((typ*)params.TAU, (typ*)args[1], &tau_in); 3609 not_ok = call_gqr(&params); 3610 if (!not_ok) { 3611 delinearize_matrix((typ*)args[2], (typ*)params.Q, &q_out); 3612 } else { 3613 error_occurred = 1; 3614 nan_matrix((typ*)args[2], &q_out); 3615 } 3616 END_OUTER_LOOP 3617 3618 release_gqr(&params); 3619 } 3620 3621 set_fp_invalid_or_clear(error_occurred); 3622} 3623 3624/* -------------------------------------------------------------------------- */ 3625 /* qr (modes - complete) */ 3626 3627template<typename ftyp> 3628static inline int 3629init_gqr_complete(GQR_PARAMS_t<ftyp> *params, 3630 fortran_int m, 3631 fortran_int n) 3632{ 3633 return init_gqr_common(params, m, n, m); 3634} 3635 3636 3637template<typename typ> 3638static void 3639qr_complete(char **args, npy_intp const *dimensions, npy_intp const *steps, 3640 void *NPY_UNUSED(func)) 3641{ 3642using ftyp = fortran_type_t<typ>; 3643 GQR_PARAMS_t<ftyp> params; 3644 int error_occurred = get_fp_invalid_and_clear(); 3645 fortran_int n, m; 3646 3647 INIT_OUTER_LOOP_3 3648 3649 m = (fortran_int)dimensions[0]; 3650 n = (fortran_int)dimensions[1]; 3651 3652 3653 if (init_gqr_complete(&params, m, n)) { 3654 LINEARIZE_DATA_t a_in, tau_in, q_out; 3655 3656 init_linearize_data(&a_in, n, m, steps[1], steps[0]); 3657 init_linearize_data(&tau_in, 1, fortran_int_min(m, n), 1, steps[2]); 3658 init_linearize_data(&q_out, m, m, steps[4], steps[3]); 3659 3660 BEGIN_OUTER_LOOP_3 3661 int not_ok; 3662 linearize_matrix((typ*)params.A, (typ*)args[0], &a_in); 3663 linearize_matrix((typ*)params.Q, (typ*)args[0], &a_in); 3664 linearize_matrix((typ*)params.TAU, (typ*)args[1], &tau_in); 3665 not_ok = call_gqr(&params); 3666 if (!not_ok) { 3667 delinearize_matrix((typ*)args[2], (typ*)params.Q, &q_out); 3668 } else { 3669 error_occurred = 1; 3670 nan_matrix((typ*)args[2], &q_out); 3671 } 3672 END_OUTER_LOOP 3673 3674 release_gqr(&params); 3675 } 3676 3677 set_fp_invalid_or_clear(error_occurred); 3678} 3679 3680/* -------------------------------------------------------------------------- */ 3681 /* least squares */ 3682 3683template<typename typ> 3684struct GELSD_PARAMS_t 3685{ 3686 fortran_int M; 3687 fortran_int N; 3688 fortran_int NRHS; 3689 typ *A; 3690 fortran_int LDA; 3691 typ *B; 3692 fortran_int LDB; 3693 basetype_t<typ> *S; 3694 basetype_t<typ> *RCOND; 3695 fortran_int RANK; 3696 typ *WORK; 3697 fortran_int LWORK; 3698 basetype_t<typ> *RWORK; 3699 fortran_int *IWORK; 3700}; 3701 3702template<typename typ> 3703static inline void 3704dump_gelsd_params(const char *name, 3705 GELSD_PARAMS_t<typ> *params) 3706{ 3707 TRACE_TXT("\n%s:\n"\ 3708 3709 "%14s: %18p\n"\ 3710 "%14s: %18p\n"\ 3711 "%14s: %18p\n"\ 3712 "%14s: %18p\n"\ 3713 "%14s: %18p\n"\ 3714 "%14s: %18p\n"\ 3715 3716 "%14s: %18d\n"\ 3717 "%14s: %18d\n"\ 3718 "%14s: %18d\n"\ 3719 "%14s: %18d\n"\ 3720 "%14s: %18d\n"\ 3721 "%14s: %18d\n"\ 3722 "%14s: %18d\n"\ 3723 3724 "%14s: %18p\n", 3725 3726 name, 3727 3728 "A", params->A, 3729 "B", params->B, 3730 "S", params->S, 3731 "WORK", params->WORK, 3732 "RWORK", params->RWORK, 3733 "IWORK", params->IWORK, 3734 3735 "M", (int)params->M, 3736 "N", (int)params->N, 3737 "NRHS", (int)params->NRHS, 3738 "LDA", (int)params->LDA, 3739 "LDB", (int)params->LDB, 3740 "LWORK", (int)params->LWORK, 3741 "RANK", (int)params->RANK, 3742 3743 "RCOND", params->RCOND); 3744} 3745 3746static inline fortran_int 3747call_gelsd(GELSD_PARAMS_t<fortran_real> *params) 3748{ 3749 fortran_int rv; 3750 LAPACK(sgelsd)(&params->M, &params->N, &params->NRHS, 3751 params->A, &params->LDA, 3752 params->B, &params->LDB, 3753 params->S, 3754 params->RCOND, &params->RANK, 3755 params->WORK, &params->LWORK, 3756 params->IWORK, 3757 &rv); 3758 return rv; 3759} 3760 3761 3762static inline fortran_int 3763call_gelsd(GELSD_PARAMS_t<fortran_doublereal> *params) 3764{ 3765 fortran_int rv; 3766 LAPACK(dgelsd)(&params->M, &params->N, &params->NRHS, 3767 params->A, &params->LDA, 3768 params->B, &params->LDB, 3769 params->S, 3770 params->RCOND, &params->RANK, 3771 params->WORK, &params->LWORK, 3772 params->IWORK, 3773 &rv); 3774 return rv; 3775} 3776 3777 3778template<typename ftyp> 3779static inline int 3780init_gelsd(GELSD_PARAMS_t<ftyp> *params, 3781 fortran_int m, 3782 fortran_int n, 3783 fortran_int nrhs, 3784scalar_trait) 3785{ 3786 npy_uint8 *mem_buff = NULL; 3787 npy_uint8 *mem_buff2 = NULL; 3788 npy_uint8 *a, *b, *s, *work, *iwork; 3789 fortran_int min_m_n = fortran_int_min(m, n); 3790 fortran_int max_m_n = fortran_int_max(m, n); 3791 size_t safe_min_m_n = min_m_n; 3792 size_t safe_max_m_n = max_m_n; 3793 size_t safe_m = m; 3794 size_t safe_n = n; 3795 size_t safe_nrhs = nrhs; 3796 3797 size_t a_size = safe_m * safe_n * sizeof(ftyp); 3798 size_t b_size = safe_max_m_n * safe_nrhs * sizeof(ftyp); 3799 size_t s_size = safe_min_m_n * sizeof(ftyp); 3800 3801 fortran_int work_count; 3802 size_t work_size; 3803 size_t iwork_size; 3804 fortran_int lda = fortran_int_max(1, m); 3805 fortran_int ldb = fortran_int_max(1, fortran_int_max(m,n)); 3806 3807 size_t msize = a_size + b_size + s_size; 3808 mem_buff = (npy_uint8 *)malloc(msize != 0 ? msize : 1); 3809 3810 if (!mem_buff) { 3811 goto no_memory; 3812 } 3813 a = mem_buff; 3814 b = a + a_size; 3815 s = b + b_size; 3816 3817 params->M = m; 3818 params->N = n; 3819 params->NRHS = nrhs; 3820 params->A = (ftyp*)a; 3821 params->B = (ftyp*)b; 3822 params->S = (ftyp*)s; 3823 params->LDA = lda; 3824 params->LDB = ldb; 3825 3826 { 3827 /* compute optimal work size */ 3828 ftyp work_size_query; 3829 fortran_int iwork_size_query; 3830 3831 params->WORK = &work_size_query; 3832 params->IWORK = &iwork_size_query; 3833 params->RWORK = NULL; 3834 params->LWORK = -1; 3835 3836 if (call_gelsd(params) != 0) { 3837 goto error; 3838 } 3839 work_count = (fortran_int)work_size_query; 3840 3841 work_size = (size_t) work_size_query * sizeof(ftyp); 3842 iwork_size = (size_t)iwork_size_query * sizeof(fortran_int); 3843 } 3844 3845 mem_buff2 = (npy_uint8 *)malloc(work_size + iwork_size); 3846 if (!mem_buff2) { 3847 goto no_memory; 3848 } 3849 work = mem_buff2; 3850 iwork = work + work_size; 3851 3852 params->WORK = (ftyp*)work; 3853 params->RWORK = NULL; 3854 params->IWORK = (fortran_int*)iwork; 3855 params->LWORK = work_count; 3856 3857 return 1; 3858 3859 no_memory: 3860 NPY_ALLOW_C_API_DEF 3861 NPY_ALLOW_C_API; 3862 PyErr_NoMemory(); 3863 NPY_DISABLE_C_API; 3864 3865 error: 3866 TRACE_TXT("%s failed init\n", __FUNCTION__); 3867 free(mem_buff); 3868 free(mem_buff2); 3869 memset(params, 0, sizeof(*params)); 3870 return 0; 3871} 3872 3873static inline fortran_int 3874call_gelsd(GELSD_PARAMS_t<fortran_complex> *params) 3875{ 3876 fortran_int rv; 3877 LAPACK(cgelsd)(&params->M, &params->N, &params->NRHS, 3878 params->A, &params->LDA, 3879 params->B, &params->LDB, 3880 params->S, 3881 params->RCOND, &params->RANK, 3882 params->WORK, &params->LWORK, 3883 params->RWORK, (fortran_int*)params->IWORK, 3884 &rv); 3885 return rv; 3886} 3887 3888static inline fortran_int 3889call_gelsd(GELSD_PARAMS_t<fortran_doublecomplex> *params) 3890{ 3891 fortran_int rv; 3892 LAPACK(zgelsd)(&params->M, &params->N, &params->NRHS, 3893 params->A, &params->LDA, 3894 params->B, &params->LDB, 3895 params->S, 3896 params->RCOND, &params->RANK, 3897 params->WORK, &params->LWORK, 3898 params->RWORK, (fortran_int*)params->IWORK, 3899 &rv); 3900 return rv; 3901} 3902 3903 3904template<typename ftyp> 3905static inline int 3906init_gelsd(GELSD_PARAMS_t<ftyp> *params, 3907 fortran_int m, 3908 fortran_int n, 3909 fortran_int nrhs, 3910complex_trait) 3911{ 3912using frealtyp = basetype_t<ftyp>; 3913 npy_uint8 *mem_buff = NULL; 3914 npy_uint8 *mem_buff2 = NULL; 3915 npy_uint8 *a, *b, *s, *work, *iwork, *rwork; 3916 fortran_int min_m_n = fortran_int_min(m, n); 3917 fortran_int max_m_n = fortran_int_max(m, n); 3918 size_t safe_min_m_n = min_m_n; 3919 size_t safe_max_m_n = max_m_n; 3920 size_t safe_m = m; 3921 size_t safe_n = n; 3922 size_t safe_nrhs = nrhs; 3923 3924 size_t a_size = safe_m * safe_n * sizeof(ftyp); 3925 size_t b_size = safe_max_m_n * safe_nrhs * sizeof(ftyp); 3926 size_t s_size = safe_min_m_n * sizeof(frealtyp); 3927 3928 fortran_int work_count; 3929 size_t work_size, rwork_size, iwork_size; 3930 fortran_int lda = fortran_int_max(1, m); 3931 fortran_int ldb = fortran_int_max(1, fortran_int_max(m,n)); 3932 3933 size_t msize = a_size + b_size + s_size; 3934 mem_buff = (npy_uint8 *)malloc(msize != 0 ? msize : 1); 3935 3936 if (!mem_buff) { 3937 goto no_memory; 3938 } 3939 3940 a = mem_buff; 3941 b = a + a_size; 3942 s = b + b_size; 3943 3944 params->M = m; 3945 params->N = n; 3946 params->NRHS = nrhs; 3947 params->A = (ftyp*)a; 3948 params->B = (ftyp*)b; 3949 params->S = (frealtyp*)s; 3950 params->LDA = lda; 3951 params->LDB = ldb; 3952 3953 { 3954 /* compute optimal work size */ 3955 ftyp work_size_query; 3956 frealtyp rwork_size_query; 3957 fortran_int iwork_size_query; 3958 3959 params->WORK = &work_size_query; 3960 params->IWORK = &iwork_size_query; 3961 params->RWORK = &rwork_size_query; 3962 params->LWORK = -1; 3963 3964 if (call_gelsd(params) != 0) { 3965 goto error; 3966 } 3967 3968 work_count = (fortran_int)work_size_query.r; 3969 3970 work_size = (size_t )work_size_query.r * sizeof(ftyp); 3971 rwork_size = (size_t)rwork_size_query * sizeof(frealtyp); 3972 iwork_size = (size_t)iwork_size_query * sizeof(fortran_int); 3973 } 3974 3975 mem_buff2 = (npy_uint8 *)malloc(work_size + rwork_size + iwork_size); 3976 if (!mem_buff2) { 3977 goto no_memory; 3978 } 3979 3980 work = mem_buff2; 3981 rwork = work + work_size; 3982 iwork = rwork + rwork_size; 3983 3984 params->WORK = (ftyp*)work; 3985 params->RWORK = (frealtyp*)rwork; 3986 params->IWORK = (fortran_int*)iwork; 3987 params->LWORK = work_count; 3988 3989 return 1; 3990 3991 no_memory: 3992 NPY_ALLOW_C_API_DEF 3993 NPY_ALLOW_C_API; 3994 PyErr_NoMemory(); 3995 NPY_DISABLE_C_API; 3996 3997 error: 3998 TRACE_TXT("%s failed init\n", __FUNCTION__); 3999 free(mem_buff); 4000 free(mem_buff2); 4001 memset(params, 0, sizeof(*params)); 4002 4003 return 0; 4004} 4005 4006template<typename ftyp> 4007static inline void 4008release_gelsd(GELSD_PARAMS_t<ftyp>* params) 4009{ 4010 /* A and WORK contain allocated blocks */ 4011 free(params->A); 4012 free(params->WORK); 4013 memset(params, 0, sizeof(*params)); 4014} 4015 4016/** Compute the squared l2 norm of a contiguous vector */ 4017template<typename typ> 4018static basetype_t<typ> 4019abs2(typ *p, npy_intp n, scalar_trait) { 4020 npy_intp i; 4021 basetype_t<typ> res = 0; 4022 for (i = 0; i < n; i++) { 4023 typ el = p[i]; 4024 res += el*el; 4025 } 4026 return res; 4027} 4028template<typename typ> 4029static basetype_t<typ> 4030abs2(typ *p, npy_intp n, complex_trait) { 4031 npy_intp i; 4032 basetype_t<typ> res = 0; 4033 for (i = 0; i < n; i++) { 4034 typ el = p[i]; 4035 res += RE(&el)*RE(&el) + IM(&el)*IM(&el); 4036 } 4037 return res; 4038} 4039 4040 4041template<typename typ> 4042static void 4043lstsq(char **args, npy_intp const *dimensions, npy_intp const *steps, 4044 void *NPY_UNUSED(func)) 4045{ 4046using ftyp = fortran_type_t<typ>; 4047using basetyp = basetype_t<typ>; 4048 GELSD_PARAMS_t<ftyp> params; 4049 int error_occurred = get_fp_invalid_and_clear(); 4050 fortran_int n, m, nrhs; 4051 fortran_int excess; 4052 4053 INIT_OUTER_LOOP_7 4054 4055 m = (fortran_int)dimensions[0]; 4056 n = (fortran_int)dimensions[1]; 4057 nrhs = (fortran_int)dimensions[2]; 4058 excess = m - n; 4059 4060 if (init_gelsd(&params, m, n, nrhs, dispatch_scalar<ftyp>{})) { 4061 LINEARIZE_DATA_t a_in, b_in, x_out, s_out, r_out; 4062 4063 init_linearize_data(&a_in, n, m, steps[1], steps[0]); 4064 init_linearize_data_ex(&b_in, nrhs, m, steps[3], steps[2], fortran_int_max(n, m)); 4065 init_linearize_data_ex(&x_out, nrhs, n, steps[5], steps[4], fortran_int_max(n, m)); 4066 init_linearize_data(&r_out, 1, nrhs, 1, steps[6]); 4067 init_linearize_data(&s_out, 1, fortran_int_min(n, m), 1, steps[7]); 4068 4069 BEGIN_OUTER_LOOP_7 4070 int not_ok; 4071 linearize_matrix((typ*)params.A, (typ*)args[0], &a_in); 4072 linearize_matrix((typ*)params.B, (typ*)args[1], &b_in); 4073 params.RCOND = (basetyp*)args[2]; 4074 not_ok = call_gelsd(&params); 4075 if (!not_ok) { 4076 delinearize_matrix((typ*)args[3], (typ*)params.B, &x_out); 4077 *(npy_int*) args[5] = params.RANK; 4078 delinearize_matrix((basetyp*)args[6], (basetyp*)params.S, &s_out); 4079 4080 /* Note that linalg.lstsq discards this when excess == 0 */ 4081 if (excess >= 0 && params.RANK == n) { 4082 /* Compute the residuals as the square sum of each column */ 4083 int i; 4084 char *resid = args[4]; 4085 ftyp *components = (ftyp *)params.B + n; 4086 for (i = 0; i < nrhs; i++) { 4087 ftyp *vector = components + i*m; 4088 /* Numpy and fortran floating types are the same size, 4089 * so this cast is safe */ 4090 basetyp abs = abs2((typ *)vector, excess, 4091dispatch_scalar<typ>{}); 4092 memcpy( 4093 resid + i*r_out.column_strides, 4094 &abs, sizeof(abs)); 4095 } 4096 } 4097 else { 4098 /* Note that this is always discarded by linalg.lstsq */ 4099 nan_matrix((basetyp*)args[4], &r_out); 4100 } 4101 } else { 4102 error_occurred = 1; 4103 nan_matrix((typ*)args[3], &x_out); 4104 nan_matrix((basetyp*)args[4], &r_out); 4105 *(npy_int*) args[5] = -1; 4106 nan_matrix((basetyp*)args[6], &s_out); 4107 } 4108 END_OUTER_LOOP 4109 4110 release_gelsd(&params); 4111 } 4112 4113 set_fp_invalid_or_clear(error_occurred); 4114} 4115 4116#pragma GCC diagnostic pop 4117 4118/* -------------------------------------------------------------------------- */ 4119 /* gufunc registration */ 4120 4121static void *array_of_nulls[] = { 4122 (void *)NULL, 4123 (void *)NULL, 4124 (void *)NULL, 4125 (void *)NULL, 4126 4127 (void *)NULL, 4128 (void *)NULL, 4129 (void *)NULL, 4130 (void *)NULL, 4131 4132 (void *)NULL, 4133 (void *)NULL, 4134 (void *)NULL, 4135 (void *)NULL, 4136 4137 (void *)NULL, 4138 (void *)NULL, 4139 (void *)NULL, 4140 (void *)NULL 4141}; 4142 4143#define FUNC_ARRAY_NAME(NAME) NAME ## _funcs 4144 4145#define GUFUNC_FUNC_ARRAY_REAL(NAME) \ 4146 static PyUFuncGenericFunction \ 4147 FUNC_ARRAY_NAME(NAME)[] = { \ 4148 FLOAT_ ## NAME, \ 4149 DOUBLE_ ## NAME \ 4150 } 4151 4152#define GUFUNC_FUNC_ARRAY_REAL_COMPLEX(NAME) \ 4153 static PyUFuncGenericFunction \ 4154 FUNC_ARRAY_NAME(NAME)[] = { \ 4155 FLOAT_ ## NAME, \ 4156 DOUBLE_ ## NAME, \ 4157 CFLOAT_ ## NAME, \ 4158 CDOUBLE_ ## NAME \ 4159 } 4160#define GUFUNC_FUNC_ARRAY_REAL_COMPLEX_(NAME) \ 4161 static PyUFuncGenericFunction \ 4162 FUNC_ARRAY_NAME(NAME)[] = { \ 4163 NAME<npy_float, npy_float>, \ 4164 NAME<npy_double, npy_double>, \ 4165 NAME<npy_cfloat, npy_float>, \ 4166 NAME<npy_cdouble, npy_double> \ 4167 } 4168#define GUFUNC_FUNC_ARRAY_REAL_COMPLEX__(NAME) \ 4169 static PyUFuncGenericFunction \ 4170 FUNC_ARRAY_NAME(NAME)[] = { \ 4171 NAME<npy_float>, \ 4172 NAME<npy_double>, \ 4173 NAME<npy_cfloat>, \ 4174 NAME<npy_cdouble> \ 4175 } 4176 4177/* There are problems with eig in complex single precision. 4178 * That kernel is disabled 4179 */ 4180#define GUFUNC_FUNC_ARRAY_EIG(NAME) \ 4181 static PyUFuncGenericFunction \ 4182 FUNC_ARRAY_NAME(NAME)[] = { \ 4183 NAME<fortran_complex,fortran_real>, \ 4184 NAME<fortran_doublecomplex,fortran_doublereal>, \ 4185 NAME<fortran_doublecomplex,fortran_doublecomplex> \ 4186 } 4187 4188/* The single precision functions are not used at all, 4189 * due to input data being promoted to double precision 4190 * in Python, so they are not implemented here. 4191 */ 4192#define GUFUNC_FUNC_ARRAY_QR(NAME) \ 4193 static PyUFuncGenericFunction \ 4194 FUNC_ARRAY_NAME(NAME)[] = { \ 4195 DOUBLE_ ## NAME, \ 4196 CDOUBLE_ ## NAME \ 4197 } 4198#define GUFUNC_FUNC_ARRAY_QR__(NAME) \ 4199 static PyUFuncGenericFunction \ 4200 FUNC_ARRAY_NAME(NAME)[] = { \ 4201 NAME<npy_double>, \ 4202 NAME<npy_cdouble> \ 4203 } 4204 4205 4206GUFUNC_FUNC_ARRAY_REAL_COMPLEX_(slogdet); 4207GUFUNC_FUNC_ARRAY_REAL_COMPLEX_(det); 4208GUFUNC_FUNC_ARRAY_REAL_COMPLEX__(eighlo); 4209GUFUNC_FUNC_ARRAY_REAL_COMPLEX__(eighup); 4210GUFUNC_FUNC_ARRAY_REAL_COMPLEX__(eigvalshlo); 4211GUFUNC_FUNC_ARRAY_REAL_COMPLEX__(eigvalshup); 4212GUFUNC_FUNC_ARRAY_REAL_COMPLEX__(solve); 4213GUFUNC_FUNC_ARRAY_REAL_COMPLEX__(solve1); 4214GUFUNC_FUNC_ARRAY_REAL_COMPLEX__(inv); 4215GUFUNC_FUNC_ARRAY_REAL_COMPLEX__(cholesky_lo); 4216GUFUNC_FUNC_ARRAY_REAL_COMPLEX__(svd_N); 4217GUFUNC_FUNC_ARRAY_REAL_COMPLEX__(svd_S); 4218GUFUNC_FUNC_ARRAY_REAL_COMPLEX__(svd_A); 4219GUFUNC_FUNC_ARRAY_QR__(qr_r_raw); 4220GUFUNC_FUNC_ARRAY_QR__(qr_reduced); 4221GUFUNC_FUNC_ARRAY_QR__(qr_complete); 4222GUFUNC_FUNC_ARRAY_REAL_COMPLEX__(lstsq); 4223GUFUNC_FUNC_ARRAY_EIG(eig); 4224GUFUNC_FUNC_ARRAY_EIG(eigvals); 4225 4226static char equal_2_types[] = { 4227 NPY_FLOAT, NPY_FLOAT, 4228 NPY_DOUBLE, NPY_DOUBLE, 4229 NPY_CFLOAT, NPY_CFLOAT, 4230 NPY_CDOUBLE, NPY_CDOUBLE 4231}; 4232 4233static char equal_3_types[] = { 4234 NPY_FLOAT, NPY_FLOAT, NPY_FLOAT, 4235 NPY_DOUBLE, NPY_DOUBLE, NPY_DOUBLE, 4236 NPY_CFLOAT, NPY_CFLOAT, NPY_CFLOAT, 4237 NPY_CDOUBLE, NPY_CDOUBLE, NPY_CDOUBLE 4238}; 4239 4240/* second result is logdet, that will always be a REAL */ 4241static char slogdet_types[] = { 4242 NPY_FLOAT, NPY_FLOAT, NPY_FLOAT, 4243 NPY_DOUBLE, NPY_DOUBLE, NPY_DOUBLE, 4244 NPY_CFLOAT, NPY_CFLOAT, NPY_FLOAT, 4245 NPY_CDOUBLE, NPY_CDOUBLE, NPY_DOUBLE 4246}; 4247 4248static char eigh_types[] = { 4249 NPY_FLOAT, NPY_FLOAT, NPY_FLOAT, 4250 NPY_DOUBLE, NPY_DOUBLE, NPY_DOUBLE, 4251 NPY_CFLOAT, NPY_FLOAT, NPY_CFLOAT, 4252 NPY_CDOUBLE, NPY_DOUBLE, NPY_CDOUBLE 4253}; 4254 4255static char eighvals_types[] = { 4256 NPY_FLOAT, NPY_FLOAT, 4257 NPY_DOUBLE, NPY_DOUBLE, 4258 NPY_CFLOAT, NPY_FLOAT, 4259 NPY_CDOUBLE, NPY_DOUBLE 4260}; 4261 4262static char eig_types[] = { 4263 NPY_FLOAT, NPY_CFLOAT, NPY_CFLOAT, 4264 NPY_DOUBLE, NPY_CDOUBLE, NPY_CDOUBLE, 4265 NPY_CDOUBLE, NPY_CDOUBLE, NPY_CDOUBLE 4266}; 4267 4268static char eigvals_types[] = { 4269 NPY_FLOAT, NPY_CFLOAT, 4270 NPY_DOUBLE, NPY_CDOUBLE, 4271 NPY_CDOUBLE, NPY_CDOUBLE 4272}; 4273 4274static char svd_1_1_types[] = { 4275 NPY_FLOAT, NPY_FLOAT, 4276 NPY_DOUBLE, NPY_DOUBLE, 4277 NPY_CFLOAT, NPY_FLOAT, 4278 NPY_CDOUBLE, NPY_DOUBLE 4279}; 4280 4281static char svd_1_3_types[] = { 4282 NPY_FLOAT, NPY_FLOAT, NPY_FLOAT, NPY_FLOAT, 4283 NPY_DOUBLE, NPY_DOUBLE, NPY_DOUBLE, NPY_DOUBLE, 4284 NPY_CFLOAT, NPY_CFLOAT, NPY_FLOAT, NPY_CFLOAT, 4285 NPY_CDOUBLE, NPY_CDOUBLE, NPY_DOUBLE, NPY_CDOUBLE 4286}; 4287 4288/* A, tau */ 4289static char qr_r_raw_types[] = { 4290 NPY_DOUBLE, NPY_DOUBLE, 4291 NPY_CDOUBLE, NPY_CDOUBLE, 4292}; 4293 4294/* A, tau, q */ 4295static char qr_reduced_types[] = { 4296 NPY_DOUBLE, NPY_DOUBLE, NPY_DOUBLE, 4297 NPY_CDOUBLE, NPY_CDOUBLE, NPY_CDOUBLE, 4298}; 4299 4300/* A, tau, q */ 4301static char qr_complete_types[] = { 4302 NPY_DOUBLE, NPY_DOUBLE, NPY_DOUBLE, 4303 NPY_CDOUBLE, NPY_CDOUBLE, NPY_CDOUBLE, 4304}; 4305 4306/* A, b, rcond, x, resid, rank, s, */ 4307static char lstsq_types[] = { 4308 NPY_FLOAT, NPY_FLOAT, NPY_FLOAT, NPY_FLOAT, NPY_FLOAT, NPY_INT, NPY_FLOAT, 4309 NPY_DOUBLE, NPY_DOUBLE, NPY_DOUBLE, NPY_DOUBLE, NPY_DOUBLE, NPY_INT, NPY_DOUBLE, 4310 NPY_CFLOAT, NPY_CFLOAT, NPY_FLOAT, NPY_CFLOAT, NPY_FLOAT, NPY_INT, NPY_FLOAT, 4311 NPY_CDOUBLE, NPY_CDOUBLE, NPY_DOUBLE, NPY_CDOUBLE, NPY_DOUBLE, NPY_INT, NPY_DOUBLE, 4312}; 4313 4314typedef struct gufunc_descriptor_struct { 4315 const char *name; 4316 const char *signature; 4317 const char *doc; 4318 int ntypes; 4319 int nin; 4320 int nout; 4321 PyUFuncGenericFunction *funcs; 4322 char *types; 4323} GUFUNC_DESCRIPTOR_t; 4324 4325GUFUNC_DESCRIPTOR_t gufunc_descriptors [] = { 4326 { 4327 "slogdet", 4328 "(m,m)->(),()", 4329 "slogdet on the last two dimensions and broadcast on the rest. \n"\ 4330 "Results in two arrays, one with sign and the other with log of the"\ 4331 " determinants. \n"\ 4332 " \"(m,m)->(),()\" \n", 4333 4, 1, 2, 4334 FUNC_ARRAY_NAME(slogdet), 4335 slogdet_types 4336 }, 4337 { 4338 "det", 4339 "(m,m)->()", 4340 "det of the last two dimensions and broadcast on the rest. \n"\ 4341 " \"(m,m)->()\" \n", 4342 4, 1, 1, 4343 FUNC_ARRAY_NAME(det), 4344 equal_2_types 4345 }, 4346 { 4347 "eigh_lo", 4348 "(m,m)->(m),(m,m)", 4349 "eigh on the last two dimension and broadcast to the rest, using"\ 4350 " lower triangle \n"\ 4351 "Results in a vector of eigenvalues and a matrix with the"\ 4352 "eigenvectors. \n"\ 4353 " \"(m,m)->(m),(m,m)\" \n", 4354 4, 1, 2, 4355 FUNC_ARRAY_NAME(eighlo), 4356 eigh_types 4357 }, 4358 { 4359 "eigh_up", 4360 "(m,m)->(m),(m,m)", 4361 "eigh on the last two dimension and broadcast to the rest, using"\ 4362 " upper triangle. \n"\ 4363 "Results in a vector of eigenvalues and a matrix with the"\ 4364 " eigenvectors. \n"\ 4365 " \"(m,m)->(m),(m,m)\" \n", 4366 4, 1, 2, 4367 FUNC_ARRAY_NAME(eighup), 4368 eigh_types 4369 }, 4370 { 4371 "eigvalsh_lo", 4372 "(m,m)->(m)", 4373 "eigh on the last two dimension and broadcast to the rest, using"\ 4374 " lower triangle. \n"\ 4375 "Results in a vector of eigenvalues and a matrix with the"\ 4376 "eigenvectors. \n"\ 4377 " \"(m,m)->(m)\" \n", 4378 4, 1, 1, 4379 FUNC_ARRAY_NAME(eigvalshlo), 4380 eighvals_types 4381 }, 4382 { 4383 "eigvalsh_up", 4384 "(m,m)->(m)", 4385 "eigvalsh on the last two dimension and broadcast to the rest,"\ 4386 " using upper triangle. \n"\ 4387 "Results in a vector of eigenvalues and a matrix with the"\ 4388 "eigenvectors.\n"\ 4389 " \"(m,m)->(m)\" \n", 4390 4, 1, 1, 4391 FUNC_ARRAY_NAME(eigvalshup), 4392 eighvals_types 4393 }, 4394 { 4395 "solve", 4396 "(m,m),(m,n)->(m,n)", 4397 "solve the system a x = b, on the last two dimensions, broadcast"\ 4398 " to the rest. \n"\ 4399 "Results in a matrices with the solutions. \n"\ 4400 " \"(m,m),(m,n)->(m,n)\" \n", 4401 4, 2, 1, 4402 FUNC_ARRAY_NAME(solve), 4403 equal_3_types 4404 }, 4405 { 4406 "solve1", 4407 "(m,m),(m)->(m)", 4408 "solve the system a x = b, for b being a vector, broadcast in"\ 4409 " the outer dimensions. \n"\ 4410 "Results in vectors with the solutions. \n"\ 4411 " \"(m,m),(m)->(m)\" \n", 4412 4, 2, 1, 4413 FUNC_ARRAY_NAME(solve1), 4414 equal_3_types 4415 }, 4416 { 4417 "inv", 4418 "(m, m)->(m, m)", 4419 "compute the inverse of the last two dimensions and broadcast"\ 4420 " to the rest. \n"\ 4421 "Results in the inverse matrices. \n"\ 4422 " \"(m,m)->(m,m)\" \n", 4423 4, 1, 1, 4424 FUNC_ARRAY_NAME(inv), 4425 equal_2_types 4426 }, 4427 { 4428 "cholesky_lo", 4429 "(m,m)->(m,m)", 4430 "cholesky decomposition of hermitian positive-definite matrices. \n"\ 4431 "Broadcast to all outer dimensions. \n"\ 4432 " \"(m,m)->(m,m)\" \n", 4433 4, 1, 1, 4434 FUNC_ARRAY_NAME(cholesky_lo), 4435 equal_2_types 4436 }, 4437 { 4438 "svd_m", 4439 "(m,n)->(m)", 4440 "svd when n>=m. ", 4441 4, 1, 1, 4442 FUNC_ARRAY_NAME(svd_N), 4443 svd_1_1_types 4444 }, 4445 { 4446 "svd_n", 4447 "(m,n)->(n)", 4448 "svd when n<=m", 4449 4, 1, 1, 4450 FUNC_ARRAY_NAME(svd_N), 4451 svd_1_1_types 4452 }, 4453 { 4454 "svd_m_s", 4455 "(m,n)->(m,m),(m),(m,n)", 4456 "svd when m<=n", 4457 4, 1, 3, 4458 FUNC_ARRAY_NAME(svd_S), 4459 svd_1_3_types 4460 }, 4461 { 4462 "svd_n_s", 4463 "(m,n)->(m,n),(n),(n,n)", 4464 "svd when m>=n", 4465 4, 1, 3, 4466 FUNC_ARRAY_NAME(svd_S), 4467 svd_1_3_types 4468 }, 4469 { 4470 "svd_m_f", 4471 "(m,n)->(m,m),(m),(n,n)", 4472 "svd when m<=n", 4473 4, 1, 3, 4474 FUNC_ARRAY_NAME(svd_A), 4475 svd_1_3_types 4476 }, 4477 { 4478 "svd_n_f", 4479 "(m,n)->(m,m),(n),(n,n)", 4480 "svd when m>=n", 4481 4, 1, 3, 4482 FUNC_ARRAY_NAME(svd_A), 4483 svd_1_3_types 4484 }, 4485 { 4486 "eig", 4487 "(m,m)->(m),(m,m)", 4488 "eig on the last two dimension and broadcast to the rest. \n"\ 4489 "Results in a vector with the eigenvalues and a matrix with the"\ 4490 " eigenvectors. \n"\ 4491 " \"(m,m)->(m),(m,m)\" \n", 4492 3, 1, 2, 4493 FUNC_ARRAY_NAME(eig), 4494 eig_types 4495 }, 4496 { 4497 "eigvals", 4498 "(m,m)->(m)", 4499 "eigvals on the last two dimension and broadcast to the rest. \n"\ 4500 "Results in a vector of eigenvalues. \n", 4501 3, 1, 1, 4502 FUNC_ARRAY_NAME(eigvals), 4503 eigvals_types 4504 }, 4505 { 4506 "qr_r_raw_m", 4507 "(m,n)->(m)", 4508 "Compute TAU vector for the last two dimensions \n"\ 4509 "and broadcast to the rest. For m <= n. \n", 4510 2, 1, 1, 4511 FUNC_ARRAY_NAME(qr_r_raw), 4512 qr_r_raw_types 4513 }, 4514 { 4515 "qr_r_raw_n", 4516 "(m,n)->(n)", 4517 "Compute TAU vector for the last two dimensions \n"\ 4518 "and broadcast to the rest. For m > n. \n", 4519 2, 1, 1, 4520 FUNC_ARRAY_NAME(qr_r_raw), 4521 qr_r_raw_types 4522 }, 4523 { 4524 "qr_reduced", 4525 "(m,n),(k)->(m,k)", 4526 "Compute Q matrix for the last two dimensions \n"\ 4527 "and broadcast to the rest. \n", 4528 2, 2, 1, 4529 FUNC_ARRAY_NAME(qr_reduced), 4530 qr_reduced_types 4531 }, 4532 { 4533 "qr_complete", 4534 "(m,n),(n)->(m,m)", 4535 "Compute Q matrix for the last two dimensions \n"\ 4536 "and broadcast to the rest. For m > n. \n", 4537 2, 2, 1, 4538 FUNC_ARRAY_NAME(qr_complete), 4539 qr_complete_types 4540 }, 4541 { 4542 "lstsq_m", 4543 "(m,n),(m,nrhs),()->(n,nrhs),(nrhs),(),(m)", 4544 "least squares on the last two dimensions and broadcast to the rest. \n"\ 4545 "For m <= n. \n", 4546 4, 3, 4, 4547 FUNC_ARRAY_NAME(lstsq), 4548 lstsq_types 4549 }, 4550 { 4551 "lstsq_n", 4552 "(m,n),(m,nrhs),()->(n,nrhs),(nrhs),(),(n)", 4553 "least squares on the last two dimensions and broadcast to the rest. \n"\ 4554 "For m >= n, meaning that residuals are produced. \n", 4555 4, 3, 4, 4556 FUNC_ARRAY_NAME(lstsq), 4557 lstsq_types 4558 } 4559}; 4560 4561static int 4562addUfuncs(PyObject *dictionary) { 4563 PyObject *f; 4564 int i; 4565 const int gufunc_count = sizeof(gufunc_descriptors)/ 4566 sizeof(gufunc_descriptors[0]); 4567 for (i = 0; i < gufunc_count; i++) { 4568 GUFUNC_DESCRIPTOR_t* d = &gufunc_descriptors[i]; 4569 f = PyUFunc_FromFuncAndDataAndSignature(d->funcs, 4570 array_of_nulls, 4571 d->types, 4572 d->ntypes, 4573 d->nin, 4574 d->nout, 4575 PyUFunc_None, 4576 d->name, 4577 d->doc, 4578 0, 4579 d->signature); 4580 if (f == NULL) { 4581 return -1; 4582 } 4583#if 0 4584 dump_ufunc_object((PyUFuncObject*) f); 4585#endif 4586 int ret = PyDict_SetItemString(dictionary, d->name, f); 4587 Py_DECREF(f); 4588 if (ret < 0) { 4589 return -1; 4590 } 4591 } 4592 return 0; 4593} 4594 4595 4596 4597/* -------------------------------------------------------------------------- */ 4598 /* Module initialization stuff */ 4599 4600static PyMethodDef UMath_LinAlgMethods[] = { 4601 {NULL, NULL, 0, NULL} /* Sentinel */ 4602}; 4603 4604static struct PyModuleDef moduledef = { 4605 PyModuleDef_HEAD_INIT, 4606 UMATH_LINALG_MODULE_NAME, 4607 NULL, 4608 -1, 4609 UMath_LinAlgMethods, 4610 NULL, 4611 NULL, 4612 NULL, 4613 NULL 4614}; 4615 4616PyMODINIT_FUNC PyInit__umath_linalg(void) 4617{ 4618 PyObject *m; 4619 PyObject *d; 4620 PyObject *version; 4621 4622 m = PyModule_Create(&moduledef); 4623 if (m == NULL) { 4624 return NULL; 4625 } 4626 4627 import_array(); 4628 import_ufunc(); 4629 4630 d = PyModule_GetDict(m); 4631 if (d == NULL) { 4632 return NULL; 4633 } 4634 4635 version = PyUnicode_FromString(umath_linalg_version_string); 4636 if (version == NULL) { 4637 return NULL; 4638 } 4639 int ret = PyDict_SetItemString(d, "__version__", version); 4640 Py_DECREF(version); 4641 if (ret < 0) { 4642 return NULL; 4643 } 4644 4645 /* Load the ufunc operators into the module's namespace */ 4646 if (addUfuncs(d) < 0) { 4647 return NULL; 4648 } 4649 4650#ifdef HAVE_BLAS_ILP64 4651 PyDict_SetItemString(d, "_ilp64", Py_True); 4652#else 4653 PyDict_SetItemString(d, "_ilp64", Py_False); 4654#endif 4655 4656 return m; 4657}