fork of https://github.com/sourcegraph/zoekt
0

Configure Feed

Select the types of activity you want to include in your feed.

Add benchmark for ctags conversion (#679)

This change adds a benchmark for the conversion from ctags output to Zoekt
document data, plus a tiny optimization to presize the symbol slices.

+4696 -2
+2 -2
build/ctags.go
··· 112 112 func tagsToSections(content []byte, tags []*ctags.Entry) ([]zoekt.DocumentSection, []*zoekt.Symbol, error) { 113 113 nls := newLinesIndices(content) 114 114 nls = append(nls, uint32(len(content))) 115 - var symOffsets []zoekt.DocumentSection 116 - var symMetaData []*zoekt.Symbol 115 + symOffsets := make([]zoekt.DocumentSection, 0, len(tags)) 116 + symMetaData := make([]*zoekt.Symbol, 0, len(tags)) 117 117 118 118 for _, t := range tags { 119 119 if t.Line <= 0 {
+37
build/ctags_test.go
··· 15 15 package build 16 16 17 17 import ( 18 + "os" 18 19 "reflect" 19 20 "testing" 20 21 ··· 229 230 }) 230 231 } 231 232 } 233 + 234 + func BenchmarkTagsToSections(b *testing.B) { 235 + if checkCTags() == "" { 236 + b.Skip("ctags not available") 237 + } 238 + 239 + file, err := os.ReadFile("./testdata/large_file.cc") 240 + parser, err := ctags.NewParser(ctags.UniversalCTags, "universal-ctags") 241 + if err != nil { 242 + b.Fatal(err) 243 + } 244 + 245 + entries, err := parser.Parse("./testdata/large_file.cc", file) 246 + if err != nil { 247 + b.Fatal(err) 248 + } 249 + 250 + secs, _, err := tagsToSections(file, entries) 251 + if err != nil { 252 + b.Fatal(err) 253 + } 254 + 255 + if len(secs) != 439 { 256 + b.Fatalf("got %d sections, want 439 sections", len(secs)) 257 + } 258 + 259 + b.ResetTimer() 260 + b.ReportAllocs() 261 + 262 + for n := 0; n < b.N; n++ { 263 + _, _, err := tagsToSections(file, entries) 264 + if err != nil { 265 + b.Fatal(err) 266 + } 267 + } 268 + }
+4657
build/testdata/large_file.cc
··· 1 + /* -*- c -*- */ 2 + 3 + /* 4 + ***************************************************************************** 5 + ** INCLUDES ** 6 + ***************************************************************************** 7 + */ 8 + #define PY_SSIZE_T_CLEAN 9 + #include <Python.h> 10 + 11 + #define NPY_NO_DEPRECATED_API NPY_API_VERSION 12 + #include "numpy/arrayobject.h" 13 + #include "numpy/ufuncobject.h" 14 + #include "numpy/npy_math.h" 15 + 16 + #include "npy_pycompat.h" 17 + 18 + #include "npy_config.h" 19 + 20 + #include "npy_cblas.h" 21 + 22 + #include <cstddef> 23 + #include <cstdio> 24 + #include <cassert> 25 + #include <cmath> 26 + #include <type_traits> 27 + #include <utility> 28 + 29 + 30 + static const char* umath_linalg_version_string = "0.1.5"; 31 + 32 + /* 33 + **************************************************************************** 34 + * Debugging support * 35 + **************************************************************************** 36 + */ 37 + #define TRACE_TXT(...) do { fprintf (stderr, __VA_ARGS__); } while (0) 38 + #define STACK_TRACE do {} while (0) 39 + #define TRACE\ 40 + do { \ 41 + fprintf (stderr, \ 42 + "%s:%d:%s\n", \ 43 + __FILE__, \ 44 + __LINE__, \ 45 + __FUNCTION__); \ 46 + STACK_TRACE; \ 47 + } while (0) 48 + 49 + #if 0 50 + #if defined HAVE_EXECINFO_H 51 + #include <execinfo.h> 52 + #elif defined HAVE_LIBUNWIND_H 53 + #include <libunwind.h> 54 + #endif 55 + void 56 + dbg_stack_trace() 57 + { 58 + void *trace[32]; 59 + size_t size; 60 + 61 + size = backtrace(trace, sizeof(trace)/sizeof(trace[0])); 62 + backtrace_symbols_fd(trace, size, 1); 63 + } 64 + 65 + #undef STACK_TRACE 66 + #define STACK_TRACE do { dbg_stack_trace(); } while (0) 67 + #endif 68 + 69 + /* 70 + ***************************************************************************** 71 + * BLAS/LAPACK calling macros * 72 + ***************************************************************************** 73 + */ 74 + 75 + #define FNAME(x) BLAS_FUNC(x) 76 + 77 + typedef CBLAS_INT fortran_int; 78 + 79 + typedef struct { float r, i; } f2c_complex; 80 + typedef struct { double r, i; } f2c_doublecomplex; 81 + /* typedef long int (*L_fp)(); */ 82 + 83 + typedef float fortran_real; 84 + typedef double fortran_doublereal; 85 + typedef f2c_complex fortran_complex; 86 + typedef f2c_doublecomplex fortran_doublecomplex; 87 + 88 + extern "C" fortran_int 89 + FNAME(sgeev)(char *jobvl, char *jobvr, fortran_int *n, 90 + float a[], fortran_int *lda, float wr[], float wi[], 91 + float vl[], fortran_int *ldvl, float vr[], fortran_int *ldvr, 92 + float work[], fortran_int lwork[], 93 + fortran_int *info); 94 + extern "C" fortran_int 95 + FNAME(dgeev)(char *jobvl, char *jobvr, fortran_int *n, 96 + double a[], fortran_int *lda, double wr[], double wi[], 97 + double vl[], fortran_int *ldvl, double vr[], fortran_int *ldvr, 98 + double work[], fortran_int lwork[], 99 + fortran_int *info); 100 + extern "C" fortran_int 101 + FNAME(cgeev)(char *jobvl, char *jobvr, fortran_int *n, 102 + f2c_complex a[], fortran_int *lda, 103 + f2c_complex w[], 104 + f2c_complex vl[], fortran_int *ldvl, 105 + f2c_complex vr[], fortran_int *ldvr, 106 + f2c_complex work[], fortran_int *lwork, 107 + float rwork[], 108 + fortran_int *info); 109 + extern "C" fortran_int 110 + FNAME(zgeev)(char *jobvl, char *jobvr, fortran_int *n, 111 + f2c_doublecomplex a[], fortran_int *lda, 112 + f2c_doublecomplex w[], 113 + f2c_doublecomplex vl[], fortran_int *ldvl, 114 + f2c_doublecomplex vr[], fortran_int *ldvr, 115 + f2c_doublecomplex work[], fortran_int *lwork, 116 + double rwork[], 117 + fortran_int *info); 118 + 119 + extern "C" fortran_int 120 + FNAME(ssyevd)(char *jobz, char *uplo, fortran_int *n, 121 + float a[], fortran_int *lda, float w[], float work[], 122 + fortran_int *lwork, fortran_int iwork[], fortran_int *liwork, 123 + fortran_int *info); 124 + extern "C" fortran_int 125 + FNAME(dsyevd)(char *jobz, char *uplo, fortran_int *n, 126 + double a[], fortran_int *lda, double w[], double work[], 127 + fortran_int *lwork, fortran_int iwork[], fortran_int *liwork, 128 + fortran_int *info); 129 + extern "C" fortran_int 130 + FNAME(cheevd)(char *jobz, char *uplo, fortran_int *n, 131 + f2c_complex a[], fortran_int *lda, 132 + float w[], f2c_complex work[], 133 + fortran_int *lwork, float rwork[], fortran_int *lrwork, fortran_int iwork[], 134 + fortran_int *liwork, 135 + fortran_int *info); 136 + extern "C" fortran_int 137 + FNAME(zheevd)(char *jobz, char *uplo, fortran_int *n, 138 + f2c_doublecomplex a[], fortran_int *lda, 139 + double w[], f2c_doublecomplex work[], 140 + fortran_int *lwork, double rwork[], fortran_int *lrwork, fortran_int iwork[], 141 + fortran_int *liwork, 142 + fortran_int *info); 143 + 144 + extern "C" fortran_int 145 + FNAME(sgelsd)(fortran_int *m, fortran_int *n, fortran_int *nrhs, 146 + float a[], fortran_int *lda, float b[], fortran_int *ldb, 147 + float s[], float *rcond, fortran_int *rank, 148 + float work[], fortran_int *lwork, fortran_int iwork[], 149 + fortran_int *info); 150 + extern "C" fortran_int 151 + FNAME(dgelsd)(fortran_int *m, fortran_int *n, fortran_int *nrhs, 152 + double a[], fortran_int *lda, double b[], fortran_int *ldb, 153 + double s[], double *rcond, fortran_int *rank, 154 + double work[], fortran_int *lwork, fortran_int iwork[], 155 + fortran_int *info); 156 + extern "C" fortran_int 157 + FNAME(cgelsd)(fortran_int *m, fortran_int *n, fortran_int *nrhs, 158 + f2c_complex a[], fortran_int *lda, 159 + f2c_complex b[], fortran_int *ldb, 160 + float s[], float *rcond, fortran_int *rank, 161 + f2c_complex work[], fortran_int *lwork, 162 + float rwork[], fortran_int iwork[], 163 + fortran_int *info); 164 + extern "C" fortran_int 165 + FNAME(zgelsd)(fortran_int *m, fortran_int *n, fortran_int *nrhs, 166 + f2c_doublecomplex a[], fortran_int *lda, 167 + f2c_doublecomplex b[], fortran_int *ldb, 168 + double s[], double *rcond, fortran_int *rank, 169 + f2c_doublecomplex work[], fortran_int *lwork, 170 + double rwork[], fortran_int iwork[], 171 + fortran_int *info); 172 + 173 + extern "C" fortran_int 174 + FNAME(dgeqrf)(fortran_int *m, fortran_int *n, double a[], fortran_int *lda, 175 + double tau[], double work[], 176 + fortran_int *lwork, fortran_int *info); 177 + extern "C" fortran_int 178 + FNAME(zgeqrf)(fortran_int *m, fortran_int *n, f2c_doublecomplex a[], fortran_int *lda, 179 + f2c_doublecomplex tau[], f2c_doublecomplex work[], 180 + fortran_int *lwork, fortran_int *info); 181 + 182 + extern "C" fortran_int 183 + FNAME(dorgqr)(fortran_int *m, fortran_int *n, fortran_int *k, double a[], fortran_int *lda, 184 + double tau[], double work[], 185 + fortran_int *lwork, fortran_int *info); 186 + extern "C" fortran_int 187 + FNAME(zungqr)(fortran_int *m, fortran_int *n, fortran_int *k, f2c_doublecomplex a[], 188 + fortran_int *lda, f2c_doublecomplex tau[], 189 + f2c_doublecomplex work[], fortran_int *lwork, fortran_int *info); 190 + 191 + extern "C" fortran_int 192 + FNAME(sgesv)(fortran_int *n, fortran_int *nrhs, 193 + float a[], fortran_int *lda, 194 + fortran_int ipiv[], 195 + float b[], fortran_int *ldb, 196 + fortran_int *info); 197 + extern "C" fortran_int 198 + FNAME(dgesv)(fortran_int *n, fortran_int *nrhs, 199 + double a[], fortran_int *lda, 200 + fortran_int ipiv[], 201 + double b[], fortran_int *ldb, 202 + fortran_int *info); 203 + extern "C" fortran_int 204 + FNAME(cgesv)(fortran_int *n, fortran_int *nrhs, 205 + f2c_complex a[], fortran_int *lda, 206 + fortran_int ipiv[], 207 + f2c_complex b[], fortran_int *ldb, 208 + fortran_int *info); 209 + extern "C" fortran_int 210 + FNAME(zgesv)(fortran_int *n, fortran_int *nrhs, 211 + f2c_doublecomplex a[], fortran_int *lda, 212 + fortran_int ipiv[], 213 + f2c_doublecomplex b[], fortran_int *ldb, 214 + fortran_int *info); 215 + 216 + extern "C" fortran_int 217 + FNAME(sgetrf)(fortran_int *m, fortran_int *n, 218 + float a[], fortran_int *lda, 219 + fortran_int ipiv[], 220 + fortran_int *info); 221 + extern "C" fortran_int 222 + FNAME(dgetrf)(fortran_int *m, fortran_int *n, 223 + double a[], fortran_int *lda, 224 + fortran_int ipiv[], 225 + fortran_int *info); 226 + extern "C" fortran_int 227 + FNAME(cgetrf)(fortran_int *m, fortran_int *n, 228 + f2c_complex a[], fortran_int *lda, 229 + fortran_int ipiv[], 230 + fortran_int *info); 231 + extern "C" fortran_int 232 + FNAME(zgetrf)(fortran_int *m, fortran_int *n, 233 + f2c_doublecomplex a[], fortran_int *lda, 234 + fortran_int ipiv[], 235 + fortran_int *info); 236 + 237 + extern "C" fortran_int 238 + FNAME(spotrf)(char *uplo, fortran_int *n, 239 + float a[], fortran_int *lda, 240 + fortran_int *info); 241 + extern "C" fortran_int 242 + FNAME(dpotrf)(char *uplo, fortran_int *n, 243 + double a[], fortran_int *lda, 244 + fortran_int *info); 245 + extern "C" fortran_int 246 + FNAME(cpotrf)(char *uplo, fortran_int *n, 247 + f2c_complex a[], fortran_int *lda, 248 + fortran_int *info); 249 + extern "C" fortran_int 250 + FNAME(zpotrf)(char *uplo, fortran_int *n, 251 + f2c_doublecomplex a[], fortran_int *lda, 252 + fortran_int *info); 253 + 254 + extern "C" fortran_int 255 + FNAME(sgesdd)(char *jobz, fortran_int *m, fortran_int *n, 256 + float a[], fortran_int *lda, float s[], float u[], 257 + fortran_int *ldu, float vt[], fortran_int *ldvt, float work[], 258 + fortran_int *lwork, fortran_int iwork[], fortran_int *info); 259 + extern "C" fortran_int 260 + FNAME(dgesdd)(char *jobz, fortran_int *m, fortran_int *n, 261 + double a[], fortran_int *lda, double s[], double u[], 262 + fortran_int *ldu, double vt[], fortran_int *ldvt, double work[], 263 + fortran_int *lwork, fortran_int iwork[], fortran_int *info); 264 + extern "C" fortran_int 265 + FNAME(cgesdd)(char *jobz, fortran_int *m, fortran_int *n, 266 + f2c_complex a[], fortran_int *lda, 267 + float s[], f2c_complex u[], fortran_int *ldu, 268 + f2c_complex vt[], fortran_int *ldvt, 269 + f2c_complex work[], fortran_int *lwork, 270 + float rwork[], fortran_int iwork[], fortran_int *info); 271 + extern "C" fortran_int 272 + FNAME(zgesdd)(char *jobz, fortran_int *m, fortran_int *n, 273 + f2c_doublecomplex a[], fortran_int *lda, 274 + double s[], f2c_doublecomplex u[], fortran_int *ldu, 275 + f2c_doublecomplex vt[], fortran_int *ldvt, 276 + f2c_doublecomplex work[], fortran_int *lwork, 277 + double rwork[], fortran_int iwork[], fortran_int *info); 278 + 279 + extern "C" fortran_int 280 + FNAME(spotrs)(char *uplo, fortran_int *n, fortran_int *nrhs, 281 + float a[], fortran_int *lda, 282 + float b[], fortran_int *ldb, 283 + fortran_int *info); 284 + extern "C" fortran_int 285 + FNAME(dpotrs)(char *uplo, fortran_int *n, fortran_int *nrhs, 286 + double a[], fortran_int *lda, 287 + double b[], fortran_int *ldb, 288 + fortran_int *info); 289 + extern "C" fortran_int 290 + FNAME(cpotrs)(char *uplo, fortran_int *n, fortran_int *nrhs, 291 + f2c_complex a[], fortran_int *lda, 292 + f2c_complex b[], fortran_int *ldb, 293 + fortran_int *info); 294 + extern "C" fortran_int 295 + FNAME(zpotrs)(char *uplo, fortran_int *n, fortran_int *nrhs, 296 + f2c_doublecomplex a[], fortran_int *lda, 297 + f2c_doublecomplex b[], fortran_int *ldb, 298 + fortran_int *info); 299 + 300 + extern "C" fortran_int 301 + FNAME(spotri)(char *uplo, fortran_int *n, 302 + float a[], fortran_int *lda, 303 + fortran_int *info); 304 + extern "C" fortran_int 305 + FNAME(dpotri)(char *uplo, fortran_int *n, 306 + double a[], fortran_int *lda, 307 + fortran_int *info); 308 + extern "C" fortran_int 309 + FNAME(cpotri)(char *uplo, fortran_int *n, 310 + f2c_complex a[], fortran_int *lda, 311 + fortran_int *info); 312 + extern "C" fortran_int 313 + FNAME(zpotri)(char *uplo, fortran_int *n, 314 + f2c_doublecomplex a[], fortran_int *lda, 315 + fortran_int *info); 316 + 317 + extern "C" fortran_int 318 + FNAME(scopy)(fortran_int *n, 319 + float *sx, fortran_int *incx, 320 + float *sy, fortran_int *incy); 321 + extern "C" fortran_int 322 + FNAME(dcopy)(fortran_int *n, 323 + double *sx, fortran_int *incx, 324 + double *sy, fortran_int *incy); 325 + extern "C" fortran_int 326 + FNAME(ccopy)(fortran_int *n, 327 + f2c_complex *sx, fortran_int *incx, 328 + f2c_complex *sy, fortran_int *incy); 329 + extern "C" fortran_int 330 + FNAME(zcopy)(fortran_int *n, 331 + f2c_doublecomplex *sx, fortran_int *incx, 332 + f2c_doublecomplex *sy, fortran_int *incy); 333 + 334 + extern "C" float 335 + FNAME(sdot)(fortran_int *n, 336 + float *sx, fortran_int *incx, 337 + float *sy, fortran_int *incy); 338 + extern "C" double 339 + FNAME(ddot)(fortran_int *n, 340 + double *sx, fortran_int *incx, 341 + double *sy, fortran_int *incy); 342 + extern "C" void 343 + FNAME(cdotu)(f2c_complex *ret, fortran_int *n, 344 + f2c_complex *sx, fortran_int *incx, 345 + f2c_complex *sy, fortran_int *incy); 346 + extern "C" void 347 + FNAME(zdotu)(f2c_doublecomplex *ret, fortran_int *n, 348 + f2c_doublecomplex *sx, fortran_int *incx, 349 + f2c_doublecomplex *sy, fortran_int *incy); 350 + extern "C" void 351 + FNAME(cdotc)(f2c_complex *ret, fortran_int *n, 352 + f2c_complex *sx, fortran_int *incx, 353 + f2c_complex *sy, fortran_int *incy); 354 + extern "C" void 355 + FNAME(zdotc)(f2c_doublecomplex *ret, fortran_int *n, 356 + f2c_doublecomplex *sx, fortran_int *incx, 357 + f2c_doublecomplex *sy, fortran_int *incy); 358 + 359 + extern "C" fortran_int 360 + FNAME(sgemm)(char *transa, char *transb, 361 + fortran_int *m, fortran_int *n, fortran_int *k, 362 + float *alpha, 363 + float *a, fortran_int *lda, 364 + float *b, fortran_int *ldb, 365 + float *beta, 366 + float *c, fortran_int *ldc); 367 + extern "C" fortran_int 368 + FNAME(dgemm)(char *transa, char *transb, 369 + fortran_int *m, fortran_int *n, fortran_int *k, 370 + double *alpha, 371 + double *a, fortran_int *lda, 372 + double *b, fortran_int *ldb, 373 + double *beta, 374 + double *c, fortran_int *ldc); 375 + extern "C" fortran_int 376 + FNAME(cgemm)(char *transa, char *transb, 377 + fortran_int *m, fortran_int *n, fortran_int *k, 378 + f2c_complex *alpha, 379 + f2c_complex *a, fortran_int *lda, 380 + f2c_complex *b, fortran_int *ldb, 381 + f2c_complex *beta, 382 + f2c_complex *c, fortran_int *ldc); 383 + extern "C" fortran_int 384 + FNAME(zgemm)(char *transa, char *transb, 385 + fortran_int *m, fortran_int *n, fortran_int *k, 386 + f2c_doublecomplex *alpha, 387 + f2c_doublecomplex *a, fortran_int *lda, 388 + f2c_doublecomplex *b, fortran_int *ldb, 389 + f2c_doublecomplex *beta, 390 + f2c_doublecomplex *c, fortran_int *ldc); 391 + 392 + 393 + #define LAPACK_T(FUNC) \ 394 + TRACE_TXT("Calling LAPACK ( " # FUNC " )\n"); \ 395 + FNAME(FUNC) 396 + 397 + #define BLAS(FUNC) \ 398 + FNAME(FUNC) 399 + 400 + #define LAPACK(FUNC) \ 401 + FNAME(FUNC) 402 + 403 + 404 + /* 405 + ***************************************************************************** 406 + ** Some handy functions ** 407 + ***************************************************************************** 408 + */ 409 + 410 + static inline int 411 + get_fp_invalid_and_clear(void) 412 + { 413 + int status; 414 + status = npy_clear_floatstatus_barrier((char*)&status); 415 + return !!(status & NPY_FPE_INVALID); 416 + } 417 + 418 + static inline void 419 + set_fp_invalid_or_clear(int error_occurred) 420 + { 421 + if (error_occurred) { 422 + npy_set_floatstatus_invalid(); 423 + } 424 + else { 425 + npy_clear_floatstatus_barrier((char*)&error_occurred); 426 + } 427 + } 428 + 429 + /* 430 + ***************************************************************************** 431 + ** Some handy constants ** 432 + ***************************************************************************** 433 + */ 434 + 435 + #define UMATH_LINALG_MODULE_NAME "_umath_linalg" 436 + 437 + template<typename T> 438 + struct numeric_limits; 439 + 440 + template<> 441 + struct numeric_limits<float> { 442 + static constexpr float one = 1.0f; 443 + static constexpr float zero = 0.0f; 444 + static constexpr float minus_one = -1.0f; 445 + static const float ninf; 446 + static const float nan; 447 + }; 448 + constexpr float numeric_limits<float>::one; 449 + constexpr float numeric_limits<float>::zero; 450 + constexpr float numeric_limits<float>::minus_one; 451 + const float numeric_limits<float>::ninf = -NPY_INFINITYF; 452 + const float numeric_limits<float>::nan = NPY_NANF; 453 + 454 + template<> 455 + struct numeric_limits<double> { 456 + static constexpr double one = 1.0; 457 + static constexpr double zero = 0.0; 458 + static constexpr double minus_one = -1.0; 459 + static const double ninf; 460 + static const double nan; 461 + }; 462 + constexpr double numeric_limits<double>::one; 463 + constexpr double numeric_limits<double>::zero; 464 + constexpr double numeric_limits<double>::minus_one; 465 + const double numeric_limits<double>::ninf = -NPY_INFINITY; 466 + const double numeric_limits<double>::nan = NPY_NAN; 467 + 468 + #if defined(_MSC_VER) && !defined(__INTEL_COMPILER) 469 + template<> 470 + struct numeric_limits<npy_cfloat> { 471 + static constexpr npy_cfloat one = {1.0f, 0.0f}; 472 + static constexpr npy_cfloat zero = {0.0f, 0.0f}; 473 + static constexpr npy_cfloat minus_one = {-1.0f, 0.0f}; 474 + static const npy_cfloat ninf; 475 + static const npy_cfloat nan; 476 + }; 477 + constexpr npy_cfloat numeric_limits<npy_cfloat>::one; 478 + constexpr npy_cfloat numeric_limits<npy_cfloat>::zero; 479 + constexpr npy_cfloat numeric_limits<npy_cfloat>::minus_one; 480 + const npy_cfloat numeric_limits<npy_cfloat>::ninf = {-NPY_INFINITYF, 0.0f}; 481 + const npy_cfloat numeric_limits<npy_cfloat>::nan = {NPY_NANF, NPY_NANF}; 482 + #else 483 + template<> 484 + struct numeric_limits<npy_cfloat> { 485 + static constexpr npy_cfloat one = 1.0f; 486 + static constexpr npy_cfloat zero = 0.0f; 487 + static constexpr npy_cfloat minus_one = -1.0f; 488 + static const npy_cfloat ninf; 489 + static const npy_cfloat nan; 490 + }; 491 + constexpr npy_cfloat numeric_limits<npy_cfloat>::one; 492 + constexpr npy_cfloat numeric_limits<npy_cfloat>::zero; 493 + constexpr npy_cfloat numeric_limits<npy_cfloat>::minus_one; 494 + const npy_cfloat numeric_limits<npy_cfloat>::ninf = -NPY_INFINITYF; 495 + const npy_cfloat numeric_limits<npy_cfloat>::nan = NPY_NANF; 496 + #endif 497 + 498 + template<> 499 + struct numeric_limits<f2c_complex> { 500 + static constexpr f2c_complex one = {1.0f, 0.0f}; 501 + static constexpr f2c_complex zero = {0.0f, 0.0f}; 502 + static constexpr f2c_complex minus_one = {-1.0f, 0.0f}; 503 + static const f2c_complex ninf; 504 + static const f2c_complex nan; 505 + }; 506 + constexpr f2c_complex numeric_limits<f2c_complex>::one; 507 + constexpr f2c_complex numeric_limits<f2c_complex>::zero; 508 + constexpr f2c_complex numeric_limits<f2c_complex>::minus_one; 509 + const f2c_complex numeric_limits<f2c_complex>::ninf = {-NPY_INFINITYF, 0.0f}; 510 + const f2c_complex numeric_limits<f2c_complex>::nan = {NPY_NANF, NPY_NANF}; 511 + 512 + #if defined(_MSC_VER) && !defined(__INTEL_COMPILER) 513 + template<> 514 + struct numeric_limits<npy_cdouble> { 515 + static constexpr npy_cdouble one = {1.0, 0.0}; 516 + static constexpr npy_cdouble zero = {0.0, 0.0}; 517 + static constexpr npy_cdouble minus_one = {-1.0, 0.0}; 518 + static const npy_cdouble ninf; 519 + static const npy_cdouble nan; 520 + }; 521 + constexpr npy_cdouble numeric_limits<npy_cdouble>::one; 522 + constexpr npy_cdouble numeric_limits<npy_cdouble>::zero; 523 + constexpr npy_cdouble numeric_limits<npy_cdouble>::minus_one; 524 + const npy_cdouble numeric_limits<npy_cdouble>::ninf = {-NPY_INFINITY, 0.0}; 525 + const npy_cdouble numeric_limits<npy_cdouble>::nan = {NPY_NAN, NPY_NAN}; 526 + #else 527 + template<> 528 + struct numeric_limits<npy_cdouble> { 529 + static constexpr npy_cdouble one = 1.0; 530 + static constexpr npy_cdouble zero = 0.0; 531 + static constexpr npy_cdouble minus_one = -1.0; 532 + static const npy_cdouble ninf; 533 + static const npy_cdouble nan; 534 + }; 535 + constexpr npy_cdouble numeric_limits<npy_cdouble>::one; 536 + constexpr npy_cdouble numeric_limits<npy_cdouble>::zero; 537 + constexpr npy_cdouble numeric_limits<npy_cdouble>::minus_one; 538 + const npy_cdouble numeric_limits<npy_cdouble>::ninf = -NPY_INFINITY; 539 + const npy_cdouble numeric_limits<npy_cdouble>::nan = NPY_NAN; 540 + #endif 541 + 542 + template<> 543 + struct numeric_limits<f2c_doublecomplex> { 544 + static constexpr f2c_doublecomplex one = {1.0, 0.0}; 545 + static constexpr f2c_doublecomplex zero = {0.0, 0.0}; 546 + static constexpr f2c_doublecomplex minus_one = {-1.0, 0.0}; 547 + static const f2c_doublecomplex ninf; 548 + static const f2c_doublecomplex nan; 549 + }; 550 + constexpr f2c_doublecomplex numeric_limits<f2c_doublecomplex>::one; 551 + constexpr f2c_doublecomplex numeric_limits<f2c_doublecomplex>::zero; 552 + constexpr f2c_doublecomplex numeric_limits<f2c_doublecomplex>::minus_one; 553 + const f2c_doublecomplex numeric_limits<f2c_doublecomplex>::ninf = {-NPY_INFINITY, 0.0}; 554 + const f2c_doublecomplex numeric_limits<f2c_doublecomplex>::nan = {NPY_NAN, NPY_NAN}; 555 + 556 + /* 557 + ***************************************************************************** 558 + ** Structs used for data rearrangement ** 559 + ***************************************************************************** 560 + */ 561 + 562 + 563 + /* 564 + * this struct contains information about how to linearize a matrix in a local 565 + * buffer so that it can be used by blas functions. All strides are specified 566 + * in bytes and are converted to elements later in type specific functions. 567 + * 568 + * rows: number of rows in the matrix 569 + * columns: number of columns in the matrix 570 + * row_strides: the number bytes between consecutive rows. 571 + * column_strides: the number of bytes between consecutive columns. 572 + * output_lead_dim: BLAS/LAPACK-side leading dimension, in elements 573 + */ 574 + typedef struct linearize_data_struct 575 + { 576 + npy_intp rows; 577 + npy_intp columns; 578 + npy_intp row_strides; 579 + npy_intp column_strides; 580 + npy_intp output_lead_dim; 581 + } LINEARIZE_DATA_t; 582 + 583 + static inline void 584 + init_linearize_data_ex(LINEARIZE_DATA_t *lin_data, 585 + npy_intp rows, 586 + npy_intp columns, 587 + npy_intp row_strides, 588 + npy_intp column_strides, 589 + npy_intp output_lead_dim) 590 + { 591 + lin_data->rows = rows; 592 + lin_data->columns = columns; 593 + lin_data->row_strides = row_strides; 594 + lin_data->column_strides = column_strides; 595 + lin_data->output_lead_dim = output_lead_dim; 596 + } 597 + 598 + static inline void 599 + init_linearize_data(LINEARIZE_DATA_t *lin_data, 600 + npy_intp rows, 601 + npy_intp columns, 602 + npy_intp row_strides, 603 + npy_intp column_strides) 604 + { 605 + init_linearize_data_ex( 606 + lin_data, rows, columns, row_strides, column_strides, columns); 607 + } 608 + 609 + static inline void 610 + dump_ufunc_object(PyUFuncObject* ufunc) 611 + { 612 + TRACE_TXT("\n\n%s '%s' (%d input(s), %d output(s), %d specialization(s).\n", 613 + ufunc->core_enabled? "generalized ufunc" : "scalar ufunc", 614 + ufunc->name, ufunc->nin, ufunc->nout, ufunc->ntypes); 615 + if (ufunc->core_enabled) { 616 + int arg; 617 + int dim; 618 + TRACE_TXT("\t%s (%d dimension(s) detected).\n", 619 + ufunc->core_signature, ufunc->core_num_dim_ix); 620 + 621 + for (arg = 0; arg < ufunc->nargs; arg++){ 622 + int * arg_dim_ix = ufunc->core_dim_ixs + ufunc->core_offsets[arg]; 623 + TRACE_TXT("\t\targ %d (%s) has %d dimension(s): (", 624 + arg, arg < ufunc->nin? "INPUT" : "OUTPUT", 625 + ufunc->core_num_dims[arg]); 626 + for (dim = 0; dim < ufunc->core_num_dims[arg]; dim ++) { 627 + TRACE_TXT(" %d", arg_dim_ix[dim]); 628 + } 629 + TRACE_TXT(" )\n"); 630 + } 631 + } 632 + } 633 + 634 + static inline void 635 + dump_linearize_data(const char* name, const LINEARIZE_DATA_t* params) 636 + { 637 + TRACE_TXT("\n\t%s rows: %zd columns: %zd"\ 638 + "\n\t\trow_strides: %td column_strides: %td"\ 639 + "\n", name, params->rows, params->columns, 640 + params->row_strides, params->column_strides); 641 + } 642 + 643 + static inline void 644 + print(npy_float s) 645 + { 646 + TRACE_TXT(" %8.4f", s); 647 + } 648 + static inline void 649 + print(npy_double d) 650 + { 651 + TRACE_TXT(" %10.6f", d); 652 + } 653 + static inline void 654 + print(npy_cfloat c) 655 + { 656 + float* c_parts = (float*)&c; 657 + TRACE_TXT("(%8.4f, %8.4fj)", c_parts[0], c_parts[1]); 658 + } 659 + static inline void 660 + print(npy_cdouble z) 661 + { 662 + double* z_parts = (double*)&z; 663 + TRACE_TXT("(%8.4f, %8.4fj)", z_parts[0], z_parts[1]); 664 + } 665 + 666 + template<typename typ> 667 + static inline void 668 + dump_matrix(const char* name, 669 + size_t rows, size_t columns, 670 + const typ* ptr) 671 + { 672 + size_t i, j; 673 + 674 + TRACE_TXT("\n%s %p (%zd, %zd)\n", name, ptr, rows, columns); 675 + for (i = 0; i < rows; i++) 676 + { 677 + TRACE_TXT("| "); 678 + for (j = 0; j < columns; j++) 679 + { 680 + print(ptr[j*rows + i]); 681 + TRACE_TXT(", "); 682 + } 683 + TRACE_TXT(" |\n"); 684 + } 685 + } 686 + 687 + 688 + /* 689 + ***************************************************************************** 690 + ** Basics ** 691 + ***************************************************************************** 692 + */ 693 + 694 + static inline fortran_int 695 + fortran_int_min(fortran_int x, fortran_int y) { 696 + return x < y ? x : y; 697 + } 698 + 699 + static inline fortran_int 700 + fortran_int_max(fortran_int x, fortran_int y) { 701 + return x > y ? x : y; 702 + } 703 + 704 + #define INIT_OUTER_LOOP_1 \ 705 + npy_intp dN = *dimensions++;\ 706 + npy_intp N_;\ 707 + npy_intp s0 = *steps++; 708 + 709 + #define INIT_OUTER_LOOP_2 \ 710 + INIT_OUTER_LOOP_1\ 711 + npy_intp s1 = *steps++; 712 + 713 + #define INIT_OUTER_LOOP_3 \ 714 + INIT_OUTER_LOOP_2\ 715 + npy_intp s2 = *steps++; 716 + 717 + #define INIT_OUTER_LOOP_4 \ 718 + INIT_OUTER_LOOP_3\ 719 + npy_intp s3 = *steps++; 720 + 721 + #define INIT_OUTER_LOOP_5 \ 722 + INIT_OUTER_LOOP_4\ 723 + npy_intp s4 = *steps++; 724 + 725 + #define INIT_OUTER_LOOP_6 \ 726 + INIT_OUTER_LOOP_5\ 727 + npy_intp s5 = *steps++; 728 + 729 + #define INIT_OUTER_LOOP_7 \ 730 + INIT_OUTER_LOOP_6\ 731 + npy_intp s6 = *steps++; 732 + 733 + #define BEGIN_OUTER_LOOP_2 \ 734 + for (N_ = 0;\ 735 + N_ < dN;\ 736 + N_++, args[0] += s0,\ 737 + args[1] += s1) { 738 + 739 + #define BEGIN_OUTER_LOOP_3 \ 740 + for (N_ = 0;\ 741 + N_ < dN;\ 742 + N_++, args[0] += s0,\ 743 + args[1] += s1,\ 744 + args[2] += s2) { 745 + 746 + #define BEGIN_OUTER_LOOP_4 \ 747 + for (N_ = 0;\ 748 + N_ < dN;\ 749 + N_++, args[0] += s0,\ 750 + args[1] += s1,\ 751 + args[2] += s2,\ 752 + args[3] += s3) { 753 + 754 + #define BEGIN_OUTER_LOOP_5 \ 755 + for (N_ = 0;\ 756 + N_ < dN;\ 757 + N_++, args[0] += s0,\ 758 + args[1] += s1,\ 759 + args[2] += s2,\ 760 + args[3] += s3,\ 761 + args[4] += s4) { 762 + 763 + #define BEGIN_OUTER_LOOP_6 \ 764 + for (N_ = 0;\ 765 + N_ < dN;\ 766 + N_++, args[0] += s0,\ 767 + args[1] += s1,\ 768 + args[2] += s2,\ 769 + args[3] += s3,\ 770 + args[4] += s4,\ 771 + args[5] += s5) { 772 + 773 + #define BEGIN_OUTER_LOOP_7 \ 774 + for (N_ = 0;\ 775 + N_ < dN;\ 776 + N_++, args[0] += s0,\ 777 + args[1] += s1,\ 778 + args[2] += s2,\ 779 + args[3] += s3,\ 780 + args[4] += s4,\ 781 + args[5] += s5,\ 782 + args[6] += s6) { 783 + 784 + #define END_OUTER_LOOP } 785 + 786 + static inline void 787 + update_pointers(npy_uint8** bases, ptrdiff_t* offsets, size_t count) 788 + { 789 + size_t i; 790 + for (i = 0; i < count; ++i) { 791 + bases[i] += offsets[i]; 792 + } 793 + } 794 + 795 + 796 + /* disable -Wmaybe-uninitialized as there is some code that generate false 797 + positives with this warning 798 + */ 799 + #pragma GCC diagnostic push 800 + #pragma GCC diagnostic ignored "-Wmaybe-uninitialized" 801 + 802 + /* 803 + ***************************************************************************** 804 + ** DISPATCHER FUNCS ** 805 + ***************************************************************************** 806 + */ 807 + static fortran_int copy(fortran_int *n, 808 + float *sx, fortran_int *incx, 809 + float *sy, fortran_int *incy) { return FNAME(scopy)(n, sx, incx, 810 + sy, incy); 811 + } 812 + static fortran_int copy(fortran_int *n, 813 + double *sx, fortran_int *incx, 814 + double *sy, fortran_int *incy) { return FNAME(dcopy)(n, sx, incx, 815 + sy, incy); 816 + } 817 + static fortran_int copy(fortran_int *n, 818 + f2c_complex *sx, fortran_int *incx, 819 + f2c_complex *sy, fortran_int *incy) { return FNAME(ccopy)(n, sx, incx, 820 + sy, incy); 821 + } 822 + static fortran_int copy(fortran_int *n, 823 + f2c_doublecomplex *sx, fortran_int *incx, 824 + f2c_doublecomplex *sy, fortran_int *incy) { return FNAME(zcopy)(n, sx, incx, 825 + sy, incy); 826 + } 827 + 828 + static fortran_int getrf(fortran_int *m, fortran_int *n, float a[], fortran_int 829 + *lda, fortran_int ipiv[], fortran_int *info) { 830 + return LAPACK(sgetrf)(m, n, a, lda, ipiv, info); 831 + } 832 + static fortran_int getrf(fortran_int *m, fortran_int *n, double a[], fortran_int 833 + *lda, fortran_int ipiv[], fortran_int *info) { 834 + return LAPACK(dgetrf)(m, n, a, lda, ipiv, info); 835 + } 836 + static fortran_int getrf(fortran_int *m, fortran_int *n, f2c_complex a[], fortran_int 837 + *lda, fortran_int ipiv[], fortran_int *info) { 838 + return LAPACK(cgetrf)(m, n, a, lda, ipiv, info); 839 + } 840 + static fortran_int getrf(fortran_int *m, fortran_int *n, f2c_doublecomplex a[], fortran_int 841 + *lda, fortran_int ipiv[], fortran_int *info) { 842 + return LAPACK(zgetrf)(m, n, a, lda, ipiv, info); 843 + } 844 + 845 + /* 846 + ***************************************************************************** 847 + ** HELPER FUNCS ** 848 + ***************************************************************************** 849 + */ 850 + template<typename T> 851 + struct fortran_type { 852 + using type = T; 853 + }; 854 + 855 + template<> struct fortran_type<npy_cfloat> { using type = f2c_complex;}; 856 + template<> struct fortran_type<npy_cdouble> { using type = f2c_doublecomplex;}; 857 + template<typename T> 858 + using fortran_type_t = typename fortran_type<T>::type; 859 + 860 + template<typename T> 861 + struct basetype { 862 + using type = T; 863 + }; 864 + template<> struct basetype<npy_cfloat> { using type = npy_float;}; 865 + template<> struct basetype<npy_cdouble> { using type = npy_double;}; 866 + template<> struct basetype<f2c_complex> { using type = fortran_real;}; 867 + template<> struct basetype<f2c_doublecomplex> { using type = fortran_doublereal;}; 868 + template<typename T> 869 + using basetype_t = typename basetype<T>::type; 870 + 871 + struct scalar_trait {}; 872 + struct complex_trait {}; 873 + template<typename typ> 874 + using dispatch_scalar = typename std::conditional<sizeof(basetype_t<typ>) == sizeof(typ), scalar_trait, complex_trait>::type; 875 + 876 + 877 + /* rearranging of 2D matrices using blas */ 878 + 879 + template<typename typ> 880 + static inline void * 881 + linearize_matrix(typ *dst, 882 + typ *src, 883 + const LINEARIZE_DATA_t* data) 884 + { 885 + using ftyp = fortran_type_t<typ>; 886 + if (dst) { 887 + int i, j; 888 + typ* rv = dst; 889 + fortran_int columns = (fortran_int)data->columns; 890 + fortran_int column_strides = 891 + (fortran_int)(data->column_strides/sizeof(typ)); 892 + fortran_int one = 1; 893 + for (i = 0; i < data->rows; i++) { 894 + if (column_strides > 0) { 895 + copy(&columns, 896 + (ftyp*)src, &column_strides, 897 + (ftyp*)dst, &one); 898 + } 899 + else if (column_strides < 0) { 900 + copy(&columns, 901 + ((ftyp*)src + (columns-1)*column_strides), 902 + &column_strides, 903 + (ftyp*)dst, &one); 904 + } 905 + else { 906 + /* 907 + * Zero stride has undefined behavior in some BLAS 908 + * implementations (e.g. OSX Accelerate), so do it 909 + * manually 910 + */ 911 + for (j = 0; j < columns; ++j) { 912 + memcpy(dst + j, src, sizeof(typ)); 913 + } 914 + } 915 + src += data->row_strides/sizeof(typ); 916 + dst += data->output_lead_dim; 917 + } 918 + return rv; 919 + } else { 920 + return src; 921 + } 922 + } 923 + 924 + template<typename typ> 925 + static inline void * 926 + delinearize_matrix(typ *dst, 927 + typ *src, 928 + const LINEARIZE_DATA_t* data) 929 + { 930 + using ftyp = fortran_type_t<typ>; 931 + 932 + if (src) { 933 + int i; 934 + typ *rv = src; 935 + fortran_int columns = (fortran_int)data->columns; 936 + fortran_int column_strides = 937 + (fortran_int)(data->column_strides/sizeof(typ)); 938 + fortran_int one = 1; 939 + for (i = 0; i < data->rows; i++) { 940 + if (column_strides > 0) { 941 + copy(&columns, 942 + (ftyp*)src, &one, 943 + (ftyp*)dst, &column_strides); 944 + } 945 + else if (column_strides < 0) { 946 + copy(&columns, 947 + (ftyp*)src, &one, 948 + ((ftyp*)dst + (columns-1)*column_strides), 949 + &column_strides); 950 + } 951 + else { 952 + /* 953 + * Zero stride has undefined behavior in some BLAS 954 + * implementations (e.g. OSX Accelerate), so do it 955 + * manually 956 + */ 957 + if (columns > 0) { 958 + memcpy(dst, 959 + src + (columns-1), 960 + sizeof(typ)); 961 + } 962 + } 963 + src += data->output_lead_dim; 964 + dst += data->row_strides/sizeof(typ); 965 + } 966 + 967 + return rv; 968 + } else { 969 + return src; 970 + } 971 + } 972 + 973 + template<typename typ> 974 + static inline void 975 + nan_matrix(typ *dst, const LINEARIZE_DATA_t* data) 976 + { 977 + int i, j; 978 + for (i = 0; i < data->rows; i++) { 979 + typ *cp = dst; 980 + ptrdiff_t cs = data->column_strides/sizeof(typ); 981 + for (j = 0; j < data->columns; ++j) { 982 + *cp = numeric_limits<typ>::nan; 983 + cp += cs; 984 + } 985 + dst += data->row_strides/sizeof(typ); 986 + } 987 + } 988 + 989 + template<typename typ> 990 + static inline void 991 + zero_matrix(typ *dst, const LINEARIZE_DATA_t* data) 992 + { 993 + int i, j; 994 + for (i = 0; i < data->rows; i++) { 995 + typ *cp = dst; 996 + ptrdiff_t cs = data->column_strides/sizeof(typ); 997 + for (j = 0; j < data->columns; ++j) { 998 + *cp = numeric_limits<typ>::zero; 999 + cp += cs; 1000 + } 1001 + dst += data->row_strides/sizeof(typ); 1002 + } 1003 + } 1004 + 1005 + /* identity square matrix generation */ 1006 + template<typename typ> 1007 + static inline void 1008 + identity_matrix(typ *matrix, size_t n) 1009 + { 1010 + size_t i; 1011 + /* in IEEE floating point, zeroes are represented as bitwise 0 */ 1012 + memset((void *)matrix, 0, n*n*sizeof(typ)); 1013 + 1014 + for (i = 0; i < n; ++i) 1015 + { 1016 + *matrix = numeric_limits<typ>::one; 1017 + matrix += n+1; 1018 + } 1019 + } 1020 + 1021 + /* lower/upper triangular matrix using blas (in place) */ 1022 + 1023 + template<typename typ> 1024 + static inline void 1025 + triu_matrix(typ *matrix, size_t n) 1026 + { 1027 + size_t i, j; 1028 + matrix += n; 1029 + for (i = 1; i < n; ++i) { 1030 + for (j = 0; j < i; ++j) { 1031 + matrix[j] = numeric_limits<typ>::zero; 1032 + } 1033 + matrix += n; 1034 + } 1035 + } 1036 + 1037 + 1038 + /* -------------------------------------------------------------------------- */ 1039 + /* Determinants */ 1040 + 1041 + static npy_float npylog(npy_float f) { return npy_logf(f);} 1042 + static npy_double npylog(npy_double d) { return npy_log(d);} 1043 + static npy_float npyexp(npy_float f) { return npy_expf(f);} 1044 + static npy_double npyexp(npy_double d) { return npy_exp(d);} 1045 + 1046 + template<typename typ> 1047 + static inline void 1048 + slogdet_from_factored_diagonal(typ* src, 1049 + fortran_int m, 1050 + typ *sign, 1051 + typ *logdet) 1052 + { 1053 + typ acc_sign = *sign; 1054 + typ acc_logdet = numeric_limits<typ>::zero; 1055 + int i; 1056 + for (i = 0; i < m; i++) { 1057 + typ abs_element = *src; 1058 + if (abs_element < numeric_limits<typ>::zero) { 1059 + acc_sign = -acc_sign; 1060 + abs_element = -abs_element; 1061 + } 1062 + 1063 + acc_logdet += npylog(abs_element); 1064 + src += m+1; 1065 + } 1066 + 1067 + *sign = acc_sign; 1068 + *logdet = acc_logdet; 1069 + } 1070 + 1071 + template<typename typ> 1072 + static inline typ 1073 + det_from_slogdet(typ sign, typ logdet) 1074 + { 1075 + typ result = sign * npyexp(logdet); 1076 + return result; 1077 + } 1078 + 1079 + 1080 + npy_float npyabs(npy_cfloat z) { return npy_cabsf(z);} 1081 + npy_double npyabs(npy_cdouble z) { return npy_cabs(z);} 1082 + 1083 + inline float RE(npy_cfloat *c) { return npy_crealf(*c); } 1084 + inline double RE(npy_cdouble *c) { return npy_creal(*c); } 1085 + #if NPY_SIZEOF_COMPLEX_LONGDOUBLE != NPY_SIZEOF_COMPLEX_DOUBLE 1086 + inline longdouble_t RE(npy_clongdouble *c) { return npy_creall(*c); } 1087 + #endif 1088 + inline float IM(npy_cfloat *c) { return npy_cimagf(*c); } 1089 + inline double IM(npy_cdouble *c) { return npy_cimag(*c); } 1090 + #if NPY_SIZEOF_COMPLEX_LONGDOUBLE != NPY_SIZEOF_COMPLEX_DOUBLE 1091 + inline longdouble_t IM(npy_clongdouble *c) { return npy_cimagl(*c); } 1092 + #endif 1093 + inline void SETRE(npy_cfloat *c, float real) { npy_csetrealf(c, real); } 1094 + inline void SETRE(npy_cdouble *c, double real) { npy_csetreal(c, real); } 1095 + #if NPY_SIZEOF_COMPLEX_LONGDOUBLE != NPY_SIZEOF_COMPLEX_DOUBLE 1096 + inline void SETRE(npy_clongdouble *c, double real) { npy_csetreall(c, real); } 1097 + #endif 1098 + inline void SETIM(npy_cfloat *c, float real) { npy_csetimagf(c, real); } 1099 + inline void SETIM(npy_cdouble *c, double real) { npy_csetimag(c, real); } 1100 + #if NPY_SIZEOF_COMPLEX_LONGDOUBLE != NPY_SIZEOF_COMPLEX_DOUBLE 1101 + inline void SETIM(npy_clongdouble *c, double real) { npy_csetimagl(c, real); } 1102 + #endif 1103 + 1104 + template<typename typ> 1105 + static inline typ 1106 + mult(typ op1, typ op2) 1107 + { 1108 + typ rv; 1109 + 1110 + SETRE(&rv, RE(&op1)*RE(&op2) - IM(&op1)*IM(&op2)); 1111 + SETIM(&rv, RE(&op1)*IM(&op2) + IM(&op1)*RE(&op2)); 1112 + 1113 + return rv; 1114 + } 1115 + 1116 + 1117 + template<typename typ, typename basetyp> 1118 + static inline void 1119 + slogdet_from_factored_diagonal(typ* src, 1120 + fortran_int m, 1121 + typ *sign, 1122 + basetyp *logdet) 1123 + { 1124 + int i; 1125 + typ sign_acc = *sign; 1126 + basetyp logdet_acc = numeric_limits<basetyp>::zero; 1127 + 1128 + for (i = 0; i < m; i++) 1129 + { 1130 + basetyp abs_element = npyabs(*src); 1131 + typ sign_element; 1132 + SETRE(&sign_element, RE(src) / abs_element); 1133 + SETIM(&sign_element, IM(src) / abs_element); 1134 + 1135 + sign_acc = mult(sign_acc, sign_element); 1136 + logdet_acc += npylog(abs_element); 1137 + src += m + 1; 1138 + } 1139 + 1140 + *sign = sign_acc; 1141 + *logdet = logdet_acc; 1142 + } 1143 + 1144 + template<typename typ, typename basetyp> 1145 + static inline typ 1146 + det_from_slogdet(typ sign, basetyp logdet) 1147 + { 1148 + typ tmp; 1149 + SETRE(&tmp, npyexp(logdet)); 1150 + SETIM(&tmp, numeric_limits<basetyp>::zero); 1151 + return mult(sign, tmp); 1152 + } 1153 + 1154 + 1155 + /* As in the linalg package, the determinant is computed via LU factorization 1156 + * using LAPACK. 1157 + * slogdet computes sign + log(determinant). 1158 + * det computes sign * exp(slogdet). 1159 + */ 1160 + template<typename typ, typename basetyp> 1161 + static inline void 1162 + slogdet_single_element(fortran_int m, 1163 + typ* src, 1164 + fortran_int* pivots, 1165 + typ *sign, 1166 + basetyp *logdet) 1167 + { 1168 + using ftyp = fortran_type_t<typ>; 1169 + fortran_int info = 0; 1170 + fortran_int lda = fortran_int_max(m, 1); 1171 + int i; 1172 + /* note: done in place */ 1173 + getrf(&m, &m, (ftyp*)src, &lda, pivots, &info); 1174 + 1175 + if (info == 0) { 1176 + int change_sign = 0; 1177 + /* note: fortran uses 1 based indexing */ 1178 + for (i = 0; i < m; i++) 1179 + { 1180 + change_sign += (pivots[i] != (i+1)); 1181 + } 1182 + 1183 + *sign = (change_sign % 2)?numeric_limits<typ>::minus_one:numeric_limits<typ>::one; 1184 + slogdet_from_factored_diagonal(src, m, sign, logdet); 1185 + } else { 1186 + /* 1187 + if getrf fails, use 0 as sign and -inf as logdet 1188 + */ 1189 + *sign = numeric_limits<typ>::zero; 1190 + *logdet = numeric_limits<basetyp>::ninf; 1191 + } 1192 + } 1193 + 1194 + template<typename typ, typename basetyp> 1195 + static void 1196 + slogdet(char **args, 1197 + npy_intp const *dimensions, 1198 + npy_intp const *steps, 1199 + void *NPY_UNUSED(func)) 1200 + { 1201 + fortran_int m; 1202 + char *tmp_buff = NULL; 1203 + size_t matrix_size; 1204 + size_t pivot_size; 1205 + size_t safe_m; 1206 + /* notes: 1207 + * matrix will need to be copied always, as factorization in lapack is 1208 + * made inplace 1209 + * matrix will need to be in column-major order, as expected by lapack 1210 + * code (fortran) 1211 + * always a square matrix 1212 + * need to allocate memory for both, matrix_buffer and pivot buffer 1213 + */ 1214 + INIT_OUTER_LOOP_3 1215 + m = (fortran_int) dimensions[0]; 1216 + /* avoid empty malloc (buffers likely unused) and ensure m is `size_t` */ 1217 + safe_m = m != 0 ? m : 1; 1218 + matrix_size = safe_m * safe_m * sizeof(typ); 1219 + pivot_size = safe_m * sizeof(fortran_int); 1220 + tmp_buff = (char *)malloc(matrix_size + pivot_size); 1221 + 1222 + if (tmp_buff) { 1223 + LINEARIZE_DATA_t lin_data; 1224 + /* swapped steps to get matrix in FORTRAN order */ 1225 + init_linearize_data(&lin_data, m, m, steps[1], steps[0]); 1226 + BEGIN_OUTER_LOOP_3 1227 + linearize_matrix((typ*)tmp_buff, (typ*)args[0], &lin_data); 1228 + slogdet_single_element(m, 1229 + (typ*)tmp_buff, 1230 + (fortran_int*)(tmp_buff+matrix_size), 1231 + (typ*)args[1], 1232 + (basetyp*)args[2]); 1233 + END_OUTER_LOOP 1234 + 1235 + free(tmp_buff); 1236 + } 1237 + else { 1238 + /* TODO: Requires use of new ufunc API to indicate error return */ 1239 + NPY_ALLOW_C_API_DEF 1240 + NPY_ALLOW_C_API; 1241 + PyErr_NoMemory(); 1242 + NPY_DISABLE_C_API; 1243 + } 1244 + } 1245 + 1246 + template<typename typ, typename basetyp> 1247 + static void 1248 + det(char **args, 1249 + npy_intp const *dimensions, 1250 + npy_intp const *steps, 1251 + void *NPY_UNUSED(func)) 1252 + { 1253 + fortran_int m; 1254 + char *tmp_buff; 1255 + size_t matrix_size; 1256 + size_t pivot_size; 1257 + size_t safe_m; 1258 + /* notes: 1259 + * matrix will need to be copied always, as factorization in lapack is 1260 + * made inplace 1261 + * matrix will need to be in column-major order, as expected by lapack 1262 + * code (fortran) 1263 + * always a square matrix 1264 + * need to allocate memory for both, matrix_buffer and pivot buffer 1265 + */ 1266 + INIT_OUTER_LOOP_2 1267 + m = (fortran_int) dimensions[0]; 1268 + /* avoid empty malloc (buffers likely unused) and ensure m is `size_t` */ 1269 + safe_m = m != 0 ? m : 1; 1270 + matrix_size = safe_m * safe_m * sizeof(typ); 1271 + pivot_size = safe_m * sizeof(fortran_int); 1272 + tmp_buff = (char *)malloc(matrix_size + pivot_size); 1273 + 1274 + if (tmp_buff) { 1275 + LINEARIZE_DATA_t lin_data; 1276 + typ sign; 1277 + basetyp logdet; 1278 + /* swapped steps to get matrix in FORTRAN order */ 1279 + init_linearize_data(&lin_data, m, m, steps[1], steps[0]); 1280 + 1281 + BEGIN_OUTER_LOOP_2 1282 + linearize_matrix((typ*)tmp_buff, (typ*)args[0], &lin_data); 1283 + slogdet_single_element(m, 1284 + (typ*)tmp_buff, 1285 + (fortran_int*)(tmp_buff + matrix_size), 1286 + &sign, 1287 + &logdet); 1288 + *(typ *)args[1] = det_from_slogdet(sign, logdet); 1289 + END_OUTER_LOOP 1290 + 1291 + free(tmp_buff); 1292 + } 1293 + else { 1294 + /* TODO: Requires use of new ufunc API to indicate error return */ 1295 + NPY_ALLOW_C_API_DEF 1296 + NPY_ALLOW_C_API; 1297 + PyErr_NoMemory(); 1298 + NPY_DISABLE_C_API; 1299 + } 1300 + } 1301 + 1302 + 1303 + /* -------------------------------------------------------------------------- */ 1304 + /* Eigh family */ 1305 + 1306 + template<typename typ> 1307 + struct EIGH_PARAMS_t { 1308 + typ *A; /* matrix */ 1309 + basetype_t<typ> *W; /* eigenvalue vector */ 1310 + typ *WORK; /* main work buffer */ 1311 + basetype_t<typ> *RWORK; /* secondary work buffer (for complex versions) */ 1312 + fortran_int *IWORK; 1313 + fortran_int N; 1314 + fortran_int LWORK; 1315 + fortran_int LRWORK; 1316 + fortran_int LIWORK; 1317 + char JOBZ; 1318 + char UPLO; 1319 + fortran_int LDA; 1320 + } ; 1321 + 1322 + static inline fortran_int 1323 + call_evd(EIGH_PARAMS_t<npy_float> *params) 1324 + { 1325 + fortran_int rv; 1326 + LAPACK(ssyevd)(&params->JOBZ, &params->UPLO, &params->N, 1327 + params->A, &params->LDA, params->W, 1328 + params->WORK, &params->LWORK, 1329 + params->IWORK, &params->LIWORK, 1330 + &rv); 1331 + return rv; 1332 + } 1333 + static inline fortran_int 1334 + call_evd(EIGH_PARAMS_t<npy_double> *params) 1335 + { 1336 + fortran_int rv; 1337 + LAPACK(dsyevd)(&params->JOBZ, &params->UPLO, &params->N, 1338 + params->A, &params->LDA, params->W, 1339 + params->WORK, &params->LWORK, 1340 + params->IWORK, &params->LIWORK, 1341 + &rv); 1342 + return rv; 1343 + } 1344 + 1345 + 1346 + /* 1347 + * Initialize the parameters to use in for the lapack function _syevd 1348 + * Handles buffer allocation 1349 + */ 1350 + template<typename typ> 1351 + static inline int 1352 + init_evd(EIGH_PARAMS_t<typ>* params, char JOBZ, char UPLO, 1353 + fortran_int N, scalar_trait) 1354 + { 1355 + npy_uint8 *mem_buff = NULL; 1356 + npy_uint8 *mem_buff2 = NULL; 1357 + fortran_int lwork; 1358 + fortran_int liwork; 1359 + npy_uint8 *a, *w, *work, *iwork; 1360 + size_t safe_N = N; 1361 + size_t alloc_size = safe_N * (safe_N + 1) * sizeof(typ); 1362 + fortran_int lda = fortran_int_max(N, 1); 1363 + 1364 + mem_buff = (npy_uint8 *)malloc(alloc_size); 1365 + 1366 + if (!mem_buff) { 1367 + goto error; 1368 + } 1369 + a = mem_buff; 1370 + w = mem_buff + safe_N * safe_N * sizeof(typ); 1371 + 1372 + params->A = (typ*)a; 1373 + params->W = (typ*)w; 1374 + params->RWORK = NULL; /* unused */ 1375 + params->N = N; 1376 + params->LRWORK = 0; /* unused */ 1377 + params->JOBZ = JOBZ; 1378 + params->UPLO = UPLO; 1379 + params->LDA = lda; 1380 + 1381 + /* Work size query */ 1382 + { 1383 + typ query_work_size; 1384 + fortran_int query_iwork_size; 1385 + 1386 + params->LWORK = -1; 1387 + params->LIWORK = -1; 1388 + params->WORK = &query_work_size; 1389 + params->IWORK = &query_iwork_size; 1390 + 1391 + if (call_evd(params) != 0) { 1392 + goto error; 1393 + } 1394 + 1395 + lwork = (fortran_int)query_work_size; 1396 + liwork = query_iwork_size; 1397 + } 1398 + 1399 + mem_buff2 = (npy_uint8 *)malloc(lwork*sizeof(typ) + liwork*sizeof(fortran_int)); 1400 + if (!mem_buff2) { 1401 + goto error; 1402 + } 1403 + 1404 + work = mem_buff2; 1405 + iwork = mem_buff2 + lwork*sizeof(typ); 1406 + 1407 + params->LWORK = lwork; 1408 + params->WORK = (typ*)work; 1409 + params->LIWORK = liwork; 1410 + params->IWORK = (fortran_int*)iwork; 1411 + 1412 + return 1; 1413 + 1414 + error: 1415 + /* something failed */ 1416 + memset(params, 0, sizeof(*params)); 1417 + free(mem_buff2); 1418 + free(mem_buff); 1419 + 1420 + return 0; 1421 + } 1422 + 1423 + 1424 + static inline fortran_int 1425 + call_evd(EIGH_PARAMS_t<npy_cfloat> *params) 1426 + { 1427 + fortran_int rv; 1428 + LAPACK(cheevd)(&params->JOBZ, &params->UPLO, &params->N, 1429 + (fortran_type_t<npy_cfloat>*)params->A, &params->LDA, params->W, 1430 + (fortran_type_t<npy_cfloat>*)params->WORK, &params->LWORK, 1431 + params->RWORK, &params->LRWORK, 1432 + params->IWORK, &params->LIWORK, 1433 + &rv); 1434 + return rv; 1435 + } 1436 + 1437 + static inline fortran_int 1438 + call_evd(EIGH_PARAMS_t<npy_cdouble> *params) 1439 + { 1440 + fortran_int rv; 1441 + LAPACK(zheevd)(&params->JOBZ, &params->UPLO, &params->N, 1442 + (fortran_type_t<npy_cdouble>*)params->A, &params->LDA, params->W, 1443 + (fortran_type_t<npy_cdouble>*)params->WORK, &params->LWORK, 1444 + params->RWORK, &params->LRWORK, 1445 + params->IWORK, &params->LIWORK, 1446 + &rv); 1447 + return rv; 1448 + } 1449 + 1450 + template<typename typ> 1451 + static inline int 1452 + init_evd(EIGH_PARAMS_t<typ> *params, 1453 + char JOBZ, 1454 + char UPLO, 1455 + fortran_int N, complex_trait) 1456 + { 1457 + using basetyp = basetype_t<typ>; 1458 + using ftyp = fortran_type_t<typ>; 1459 + using fbasetyp = fortran_type_t<basetyp>; 1460 + npy_uint8 *mem_buff = NULL; 1461 + npy_uint8 *mem_buff2 = NULL; 1462 + fortran_int lwork; 1463 + fortran_int lrwork; 1464 + fortran_int liwork; 1465 + npy_uint8 *a, *w, *work, *rwork, *iwork; 1466 + size_t safe_N = N; 1467 + fortran_int lda = fortran_int_max(N, 1); 1468 + 1469 + mem_buff = (npy_uint8 *)malloc(safe_N * safe_N * sizeof(typ) + 1470 + safe_N * sizeof(basetyp)); 1471 + if (!mem_buff) { 1472 + goto error; 1473 + } 1474 + a = mem_buff; 1475 + w = mem_buff + safe_N * safe_N * sizeof(typ); 1476 + 1477 + params->A = (typ*)a; 1478 + params->W = (basetyp*)w; 1479 + params->N = N; 1480 + params->JOBZ = JOBZ; 1481 + params->UPLO = UPLO; 1482 + params->LDA = lda; 1483 + 1484 + /* Work size query */ 1485 + { 1486 + ftyp query_work_size; 1487 + fbasetyp query_rwork_size; 1488 + fortran_int query_iwork_size; 1489 + 1490 + params->LWORK = -1; 1491 + params->LRWORK = -1; 1492 + params->LIWORK = -1; 1493 + params->WORK = (typ*)&query_work_size; 1494 + params->RWORK = (basetyp*)&query_rwork_size; 1495 + params->IWORK = &query_iwork_size; 1496 + 1497 + if (call_evd(params) != 0) { 1498 + goto error; 1499 + } 1500 + 1501 + lwork = (fortran_int)*(fbasetyp*)&query_work_size; 1502 + lrwork = (fortran_int)query_rwork_size; 1503 + liwork = query_iwork_size; 1504 + } 1505 + 1506 + mem_buff2 = (npy_uint8 *)malloc(lwork*sizeof(typ) + 1507 + lrwork*sizeof(basetyp) + 1508 + liwork*sizeof(fortran_int)); 1509 + if (!mem_buff2) { 1510 + goto error; 1511 + } 1512 + 1513 + work = mem_buff2; 1514 + rwork = work + lwork*sizeof(typ); 1515 + iwork = rwork + lrwork*sizeof(basetyp); 1516 + 1517 + params->WORK = (typ*)work; 1518 + params->RWORK = (basetyp*)rwork; 1519 + params->IWORK = (fortran_int*)iwork; 1520 + params->LWORK = lwork; 1521 + params->LRWORK = lrwork; 1522 + params->LIWORK = liwork; 1523 + 1524 + return 1; 1525 + 1526 + /* something failed */ 1527 + error: 1528 + memset(params, 0, sizeof(*params)); 1529 + free(mem_buff2); 1530 + free(mem_buff); 1531 + 1532 + return 0; 1533 + } 1534 + 1535 + /* 1536 + * (M, M)->(M,)(M, M) 1537 + * dimensions[1] -> M 1538 + * args[0] -> A[in] 1539 + * args[1] -> W 1540 + * args[2] -> A[out] 1541 + */ 1542 + 1543 + template<typename typ> 1544 + static inline void 1545 + release_evd(EIGH_PARAMS_t<typ> *params) 1546 + { 1547 + /* allocated memory in A and WORK */ 1548 + free(params->A); 1549 + free(params->WORK); 1550 + memset(params, 0, sizeof(*params)); 1551 + } 1552 + 1553 + 1554 + template<typename typ> 1555 + static inline void 1556 + eigh_wrapper(char JOBZ, 1557 + char UPLO, 1558 + char**args, 1559 + npy_intp const *dimensions, 1560 + npy_intp const *steps) 1561 + { 1562 + using basetyp = basetype_t<typ>; 1563 + ptrdiff_t outer_steps[3]; 1564 + size_t iter; 1565 + size_t outer_dim = *dimensions++; 1566 + size_t op_count = (JOBZ=='N')?2:3; 1567 + EIGH_PARAMS_t<typ> eigh_params; 1568 + int error_occurred = get_fp_invalid_and_clear(); 1569 + 1570 + for (iter = 0; iter < op_count; ++iter) { 1571 + outer_steps[iter] = (ptrdiff_t) steps[iter]; 1572 + } 1573 + steps += op_count; 1574 + 1575 + if (init_evd(&eigh_params, 1576 + JOBZ, 1577 + UPLO, 1578 + (fortran_int)dimensions[0], dispatch_scalar<typ>())) { 1579 + LINEARIZE_DATA_t matrix_in_ld; 1580 + LINEARIZE_DATA_t eigenvectors_out_ld; 1581 + LINEARIZE_DATA_t eigenvalues_out_ld; 1582 + 1583 + init_linearize_data(&matrix_in_ld, 1584 + eigh_params.N, eigh_params.N, 1585 + steps[1], steps[0]); 1586 + init_linearize_data(&eigenvalues_out_ld, 1587 + 1, eigh_params.N, 1588 + 0, steps[2]); 1589 + if ('V' == eigh_params.JOBZ) { 1590 + init_linearize_data(&eigenvectors_out_ld, 1591 + eigh_params.N, eigh_params.N, 1592 + steps[4], steps[3]); 1593 + } 1594 + 1595 + for (iter = 0; iter < outer_dim; ++iter) { 1596 + int not_ok; 1597 + /* copy the matrix in */ 1598 + linearize_matrix((typ*)eigh_params.A, (typ*)args[0], &matrix_in_ld); 1599 + not_ok = call_evd(&eigh_params); 1600 + if (!not_ok) { 1601 + /* lapack ok, copy result out */ 1602 + delinearize_matrix((basetyp*)args[1], 1603 + (basetyp*)eigh_params.W, 1604 + &eigenvalues_out_ld); 1605 + 1606 + if ('V' == eigh_params.JOBZ) { 1607 + delinearize_matrix((typ*)args[2], 1608 + (typ*)eigh_params.A, 1609 + &eigenvectors_out_ld); 1610 + } 1611 + } else { 1612 + /* lapack fail, set result to nan */ 1613 + error_occurred = 1; 1614 + nan_matrix((basetyp*)args[1], &eigenvalues_out_ld); 1615 + if ('V' == eigh_params.JOBZ) { 1616 + nan_matrix((typ*)args[2], &eigenvectors_out_ld); 1617 + } 1618 + } 1619 + update_pointers((npy_uint8**)args, outer_steps, op_count); 1620 + } 1621 + 1622 + release_evd(&eigh_params); 1623 + } 1624 + 1625 + set_fp_invalid_or_clear(error_occurred); 1626 + } 1627 + 1628 + 1629 + template<typename typ> 1630 + static void 1631 + eighlo(char **args, 1632 + npy_intp const *dimensions, 1633 + npy_intp const *steps, 1634 + void *NPY_UNUSED(func)) 1635 + { 1636 + eigh_wrapper<typ>('V', 'L', args, dimensions, steps); 1637 + } 1638 + 1639 + template<typename typ> 1640 + static void 1641 + eighup(char **args, 1642 + npy_intp const *dimensions, 1643 + npy_intp const *steps, 1644 + void* NPY_UNUSED(func)) 1645 + { 1646 + eigh_wrapper<typ>('V', 'U', args, dimensions, steps); 1647 + } 1648 + 1649 + template<typename typ> 1650 + static void 1651 + eigvalshlo(char **args, 1652 + npy_intp const *dimensions, 1653 + npy_intp const *steps, 1654 + void* NPY_UNUSED(func)) 1655 + { 1656 + eigh_wrapper<typ>('N', 'L', args, dimensions, steps); 1657 + } 1658 + 1659 + template<typename typ> 1660 + static void 1661 + eigvalshup(char **args, 1662 + npy_intp const *dimensions, 1663 + npy_intp const *steps, 1664 + void* NPY_UNUSED(func)) 1665 + { 1666 + eigh_wrapper<typ>('N', 'U', args, dimensions, steps); 1667 + } 1668 + 1669 + /* -------------------------------------------------------------------------- */ 1670 + /* Solve family (includes inv) */ 1671 + 1672 + template<typename typ> 1673 + struct GESV_PARAMS_t 1674 + { 1675 + typ *A; /* A is (N, N) of base type */ 1676 + typ *B; /* B is (N, NRHS) of base type */ 1677 + fortran_int * IPIV; /* IPIV is (N) */ 1678 + 1679 + fortran_int N; 1680 + fortran_int NRHS; 1681 + fortran_int LDA; 1682 + fortran_int LDB; 1683 + }; 1684 + 1685 + static inline fortran_int 1686 + call_gesv(GESV_PARAMS_t<fortran_real> *params) 1687 + { 1688 + fortran_int rv; 1689 + LAPACK(sgesv)(&params->N, &params->NRHS, 1690 + params->A, &params->LDA, 1691 + params->IPIV, 1692 + params->B, &params->LDB, 1693 + &rv); 1694 + return rv; 1695 + } 1696 + 1697 + static inline fortran_int 1698 + call_gesv(GESV_PARAMS_t<fortran_doublereal> *params) 1699 + { 1700 + fortran_int rv; 1701 + LAPACK(dgesv)(&params->N, &params->NRHS, 1702 + params->A, &params->LDA, 1703 + params->IPIV, 1704 + params->B, &params->LDB, 1705 + &rv); 1706 + return rv; 1707 + } 1708 + 1709 + static inline fortran_int 1710 + call_gesv(GESV_PARAMS_t<fortran_complex> *params) 1711 + { 1712 + fortran_int rv; 1713 + LAPACK(cgesv)(&params->N, &params->NRHS, 1714 + params->A, &params->LDA, 1715 + params->IPIV, 1716 + params->B, &params->LDB, 1717 + &rv); 1718 + return rv; 1719 + } 1720 + 1721 + static inline fortran_int 1722 + call_gesv(GESV_PARAMS_t<fortran_doublecomplex> *params) 1723 + { 1724 + fortran_int rv; 1725 + LAPACK(zgesv)(&params->N, &params->NRHS, 1726 + params->A, &params->LDA, 1727 + params->IPIV, 1728 + params->B, &params->LDB, 1729 + &rv); 1730 + return rv; 1731 + } 1732 + 1733 + 1734 + /* 1735 + * Initialize the parameters to use in for the lapack function _heev 1736 + * Handles buffer allocation 1737 + */ 1738 + template<typename ftyp> 1739 + static inline int 1740 + init_gesv(GESV_PARAMS_t<ftyp> *params, fortran_int N, fortran_int NRHS) 1741 + { 1742 + npy_uint8 *mem_buff = NULL; 1743 + npy_uint8 *a, *b, *ipiv; 1744 + size_t safe_N = N; 1745 + size_t safe_NRHS = NRHS; 1746 + fortran_int ld = fortran_int_max(N, 1); 1747 + mem_buff = (npy_uint8 *)malloc(safe_N * safe_N * sizeof(ftyp) + 1748 + safe_N * safe_NRHS*sizeof(ftyp) + 1749 + safe_N * sizeof(fortran_int)); 1750 + if (!mem_buff) { 1751 + goto error; 1752 + } 1753 + a = mem_buff; 1754 + b = a + safe_N * safe_N * sizeof(ftyp); 1755 + ipiv = b + safe_N * safe_NRHS * sizeof(ftyp); 1756 + 1757 + params->A = (ftyp*)a; 1758 + params->B = (ftyp*)b; 1759 + params->IPIV = (fortran_int*)ipiv; 1760 + params->N = N; 1761 + params->NRHS = NRHS; 1762 + params->LDA = ld; 1763 + params->LDB = ld; 1764 + 1765 + return 1; 1766 + error: 1767 + free(mem_buff); 1768 + memset(params, 0, sizeof(*params)); 1769 + 1770 + return 0; 1771 + } 1772 + 1773 + template<typename ftyp> 1774 + static inline void 1775 + release_gesv(GESV_PARAMS_t<ftyp> *params) 1776 + { 1777 + /* memory block base is in A */ 1778 + free(params->A); 1779 + memset(params, 0, sizeof(*params)); 1780 + } 1781 + 1782 + template<typename typ> 1783 + static void 1784 + solve(char **args, npy_intp const *dimensions, npy_intp const *steps, 1785 + void *NPY_UNUSED(func)) 1786 + { 1787 + using ftyp = fortran_type_t<typ>; 1788 + GESV_PARAMS_t<ftyp> params; 1789 + fortran_int n, nrhs; 1790 + int error_occurred = get_fp_invalid_and_clear(); 1791 + INIT_OUTER_LOOP_3 1792 + 1793 + n = (fortran_int)dimensions[0]; 1794 + nrhs = (fortran_int)dimensions[1]; 1795 + if (init_gesv(&params, n, nrhs)) { 1796 + LINEARIZE_DATA_t a_in, b_in, r_out; 1797 + 1798 + init_linearize_data(&a_in, n, n, steps[1], steps[0]); 1799 + init_linearize_data(&b_in, nrhs, n, steps[3], steps[2]); 1800 + init_linearize_data(&r_out, nrhs, n, steps[5], steps[4]); 1801 + 1802 + BEGIN_OUTER_LOOP_3 1803 + int not_ok; 1804 + linearize_matrix((typ*)params.A, (typ*)args[0], &a_in); 1805 + linearize_matrix((typ*)params.B, (typ*)args[1], &b_in); 1806 + not_ok =call_gesv(&params); 1807 + if (!not_ok) { 1808 + delinearize_matrix((typ*)args[2], (typ*)params.B, &r_out); 1809 + } else { 1810 + error_occurred = 1; 1811 + nan_matrix((typ*)args[2], &r_out); 1812 + } 1813 + END_OUTER_LOOP 1814 + 1815 + release_gesv(&params); 1816 + } 1817 + 1818 + set_fp_invalid_or_clear(error_occurred); 1819 + } 1820 + 1821 + 1822 + template<typename typ> 1823 + static void 1824 + solve1(char **args, npy_intp const *dimensions, npy_intp const *steps, 1825 + void *NPY_UNUSED(func)) 1826 + { 1827 + using ftyp = fortran_type_t<typ>; 1828 + GESV_PARAMS_t<ftyp> params; 1829 + int error_occurred = get_fp_invalid_and_clear(); 1830 + fortran_int n; 1831 + INIT_OUTER_LOOP_3 1832 + 1833 + n = (fortran_int)dimensions[0]; 1834 + if (init_gesv(&params, n, 1)) { 1835 + LINEARIZE_DATA_t a_in, b_in, r_out; 1836 + init_linearize_data(&a_in, n, n, steps[1], steps[0]); 1837 + init_linearize_data(&b_in, 1, n, 1, steps[2]); 1838 + init_linearize_data(&r_out, 1, n, 1, steps[3]); 1839 + 1840 + BEGIN_OUTER_LOOP_3 1841 + int not_ok; 1842 + linearize_matrix((typ*)params.A, (typ*)args[0], &a_in); 1843 + linearize_matrix((typ*)params.B, (typ*)args[1], &b_in); 1844 + not_ok = call_gesv(&params); 1845 + if (!not_ok) { 1846 + delinearize_matrix((typ*)args[2], (typ*)params.B, &r_out); 1847 + } else { 1848 + error_occurred = 1; 1849 + nan_matrix((typ*)args[2], &r_out); 1850 + } 1851 + END_OUTER_LOOP 1852 + 1853 + release_gesv(&params); 1854 + } 1855 + 1856 + set_fp_invalid_or_clear(error_occurred); 1857 + } 1858 + 1859 + template<typename typ> 1860 + static void 1861 + inv(char **args, npy_intp const *dimensions, npy_intp const *steps, 1862 + void *NPY_UNUSED(func)) 1863 + { 1864 + using ftyp = fortran_type_t<typ>; 1865 + GESV_PARAMS_t<ftyp> params; 1866 + fortran_int n; 1867 + int error_occurred = get_fp_invalid_and_clear(); 1868 + INIT_OUTER_LOOP_2 1869 + 1870 + n = (fortran_int)dimensions[0]; 1871 + if (init_gesv(&params, n, n)) { 1872 + LINEARIZE_DATA_t a_in, r_out; 1873 + init_linearize_data(&a_in, n, n, steps[1], steps[0]); 1874 + init_linearize_data(&r_out, n, n, steps[3], steps[2]); 1875 + 1876 + BEGIN_OUTER_LOOP_2 1877 + int not_ok; 1878 + linearize_matrix((typ*)params.A, (typ*)args[0], &a_in); 1879 + identity_matrix((typ*)params.B, n); 1880 + not_ok = call_gesv(&params); 1881 + if (!not_ok) { 1882 + delinearize_matrix((typ*)args[1], (typ*)params.B, &r_out); 1883 + } else { 1884 + error_occurred = 1; 1885 + nan_matrix((typ*)args[1], &r_out); 1886 + } 1887 + END_OUTER_LOOP 1888 + 1889 + release_gesv(&params); 1890 + } 1891 + 1892 + set_fp_invalid_or_clear(error_occurred); 1893 + } 1894 + 1895 + 1896 + /* -------------------------------------------------------------------------- */ 1897 + /* Cholesky decomposition */ 1898 + 1899 + template<typename typ> 1900 + struct POTR_PARAMS_t 1901 + { 1902 + typ *A; 1903 + fortran_int N; 1904 + fortran_int LDA; 1905 + char UPLO; 1906 + }; 1907 + 1908 + 1909 + static inline fortran_int 1910 + call_potrf(POTR_PARAMS_t<fortran_real> *params) 1911 + { 1912 + fortran_int rv; 1913 + LAPACK(spotrf)(&params->UPLO, 1914 + &params->N, params->A, &params->LDA, 1915 + &rv); 1916 + return rv; 1917 + } 1918 + 1919 + static inline fortran_int 1920 + call_potrf(POTR_PARAMS_t<fortran_doublereal> *params) 1921 + { 1922 + fortran_int rv; 1923 + LAPACK(dpotrf)(&params->UPLO, 1924 + &params->N, params->A, &params->LDA, 1925 + &rv); 1926 + return rv; 1927 + } 1928 + 1929 + static inline fortran_int 1930 + call_potrf(POTR_PARAMS_t<fortran_complex> *params) 1931 + { 1932 + fortran_int rv; 1933 + LAPACK(cpotrf)(&params->UPLO, 1934 + &params->N, params->A, &params->LDA, 1935 + &rv); 1936 + return rv; 1937 + } 1938 + 1939 + static inline fortran_int 1940 + call_potrf(POTR_PARAMS_t<fortran_doublecomplex> *params) 1941 + { 1942 + fortran_int rv; 1943 + LAPACK(zpotrf)(&params->UPLO, 1944 + &params->N, params->A, &params->LDA, 1945 + &rv); 1946 + return rv; 1947 + } 1948 + 1949 + template<typename ftyp> 1950 + static inline int 1951 + init_potrf(POTR_PARAMS_t<ftyp> *params, char UPLO, fortran_int N) 1952 + { 1953 + npy_uint8 *mem_buff = NULL; 1954 + npy_uint8 *a; 1955 + size_t safe_N = N; 1956 + fortran_int lda = fortran_int_max(N, 1); 1957 + 1958 + mem_buff = (npy_uint8 *)malloc(safe_N * safe_N * sizeof(ftyp)); 1959 + if (!mem_buff) { 1960 + goto error; 1961 + } 1962 + 1963 + a = mem_buff; 1964 + 1965 + params->A = (ftyp*)a; 1966 + params->N = N; 1967 + params->LDA = lda; 1968 + params->UPLO = UPLO; 1969 + 1970 + return 1; 1971 + error: 1972 + free(mem_buff); 1973 + memset(params, 0, sizeof(*params)); 1974 + 1975 + return 0; 1976 + } 1977 + 1978 + template<typename ftyp> 1979 + static inline void 1980 + release_potrf(POTR_PARAMS_t<ftyp> *params) 1981 + { 1982 + /* memory block base in A */ 1983 + free(params->A); 1984 + memset(params, 0, sizeof(*params)); 1985 + } 1986 + 1987 + template<typename typ> 1988 + static void 1989 + cholesky(char uplo, char **args, npy_intp const *dimensions, npy_intp const *steps) 1990 + { 1991 + using ftyp = fortran_type_t<typ>; 1992 + POTR_PARAMS_t<ftyp> params; 1993 + int error_occurred = get_fp_invalid_and_clear(); 1994 + fortran_int n; 1995 + INIT_OUTER_LOOP_2 1996 + 1997 + assert(uplo == 'L'); 1998 + 1999 + n = (fortran_int)dimensions[0]; 2000 + if (init_potrf(&params, uplo, n)) { 2001 + LINEARIZE_DATA_t a_in, r_out; 2002 + init_linearize_data(&a_in, n, n, steps[1], steps[0]); 2003 + init_linearize_data(&r_out, n, n, steps[3], steps[2]); 2004 + BEGIN_OUTER_LOOP_2 2005 + int not_ok; 2006 + linearize_matrix(params.A, (ftyp*)args[0], &a_in); 2007 + not_ok = call_potrf(&params); 2008 + if (!not_ok) { 2009 + triu_matrix(params.A, params.N); 2010 + delinearize_matrix((ftyp*)args[1], params.A, &r_out); 2011 + } else { 2012 + error_occurred = 1; 2013 + nan_matrix((ftyp*)args[1], &r_out); 2014 + } 2015 + END_OUTER_LOOP 2016 + release_potrf(&params); 2017 + } 2018 + 2019 + set_fp_invalid_or_clear(error_occurred); 2020 + } 2021 + 2022 + template<typename typ> 2023 + static void 2024 + cholesky_lo(char **args, npy_intp const *dimensions, npy_intp const *steps, 2025 + void *NPY_UNUSED(func)) 2026 + { 2027 + cholesky<typ>('L', args, dimensions, steps); 2028 + } 2029 + 2030 + /* -------------------------------------------------------------------------- */ 2031 + /* eig family */ 2032 + 2033 + template<typename typ> 2034 + struct GEEV_PARAMS_t { 2035 + typ *A; 2036 + basetype_t<typ> *WR; /* RWORK in complex versions, REAL W buffer for (sd)geev*/ 2037 + typ *WI; 2038 + typ *VLR; /* REAL VL buffers for _geev where _ is s, d */ 2039 + typ *VRR; /* REAL VR buffers for _geev where _ is s, d */ 2040 + typ *WORK; 2041 + typ *W; /* final w */ 2042 + typ *VL; /* final vl */ 2043 + typ *VR; /* final vr */ 2044 + 2045 + fortran_int N; 2046 + fortran_int LDA; 2047 + fortran_int LDVL; 2048 + fortran_int LDVR; 2049 + fortran_int LWORK; 2050 + 2051 + char JOBVL; 2052 + char JOBVR; 2053 + }; 2054 + 2055 + template<typename typ> 2056 + static inline void 2057 + dump_geev_params(const char *name, GEEV_PARAMS_t<typ>* params) 2058 + { 2059 + TRACE_TXT("\n%s\n" 2060 + 2061 + "\t%10s: %p\n"\ 2062 + "\t%10s: %p\n"\ 2063 + "\t%10s: %p\n"\ 2064 + "\t%10s: %p\n"\ 2065 + "\t%10s: %p\n"\ 2066 + "\t%10s: %p\n"\ 2067 + "\t%10s: %p\n"\ 2068 + "\t%10s: %p\n"\ 2069 + "\t%10s: %p\n"\ 2070 + 2071 + "\t%10s: %d\n"\ 2072 + "\t%10s: %d\n"\ 2073 + "\t%10s: %d\n"\ 2074 + "\t%10s: %d\n"\ 2075 + "\t%10s: %d\n"\ 2076 + 2077 + "\t%10s: %c\n"\ 2078 + "\t%10s: %c\n", 2079 + 2080 + name, 2081 + 2082 + "A", params->A, 2083 + "WR", params->WR, 2084 + "WI", params->WI, 2085 + "VLR", params->VLR, 2086 + "VRR", params->VRR, 2087 + "WORK", params->WORK, 2088 + "W", params->W, 2089 + "VL", params->VL, 2090 + "VR", params->VR, 2091 + 2092 + "N", (int)params->N, 2093 + "LDA", (int)params->LDA, 2094 + "LDVL", (int)params->LDVL, 2095 + "LDVR", (int)params->LDVR, 2096 + "LWORK", (int)params->LWORK, 2097 + 2098 + "JOBVL", params->JOBVL, 2099 + "JOBVR", params->JOBVR); 2100 + } 2101 + 2102 + static inline fortran_int 2103 + call_geev(GEEV_PARAMS_t<float>* params) 2104 + { 2105 + fortran_int rv; 2106 + LAPACK(sgeev)(&params->JOBVL, &params->JOBVR, 2107 + &params->N, params->A, &params->LDA, 2108 + params->WR, params->WI, 2109 + params->VLR, &params->LDVL, 2110 + params->VRR, &params->LDVR, 2111 + params->WORK, &params->LWORK, 2112 + &rv); 2113 + return rv; 2114 + } 2115 + 2116 + static inline fortran_int 2117 + call_geev(GEEV_PARAMS_t<double>* params) 2118 + { 2119 + fortran_int rv; 2120 + LAPACK(dgeev)(&params->JOBVL, &params->JOBVR, 2121 + &params->N, params->A, &params->LDA, 2122 + params->WR, params->WI, 2123 + params->VLR, &params->LDVL, 2124 + params->VRR, &params->LDVR, 2125 + params->WORK, &params->LWORK, 2126 + &rv); 2127 + return rv; 2128 + } 2129 + 2130 + 2131 + template<typename typ> 2132 + static inline int 2133 + init_geev(GEEV_PARAMS_t<typ> *params, char jobvl, char jobvr, fortran_int n, 2134 + scalar_trait) 2135 + { 2136 + npy_uint8 *mem_buff = NULL; 2137 + npy_uint8 *mem_buff2 = NULL; 2138 + npy_uint8 *a, *wr, *wi, *vlr, *vrr, *work, *w, *vl, *vr; 2139 + size_t safe_n = n; 2140 + size_t a_size = safe_n * safe_n * sizeof(typ); 2141 + size_t wr_size = safe_n * sizeof(typ); 2142 + size_t wi_size = safe_n * sizeof(typ); 2143 + size_t vlr_size = jobvl=='V' ? safe_n * safe_n * sizeof(typ) : 0; 2144 + size_t vrr_size = jobvr=='V' ? safe_n * safe_n * sizeof(typ) : 0; 2145 + size_t w_size = wr_size*2; 2146 + size_t vl_size = vlr_size*2; 2147 + size_t vr_size = vrr_size*2; 2148 + size_t work_count = 0; 2149 + fortran_int ld = fortran_int_max(n, 1); 2150 + 2151 + /* allocate data for known sizes (all but work) */ 2152 + mem_buff = (npy_uint8 *)malloc(a_size + wr_size + wi_size + 2153 + vlr_size + vrr_size + 2154 + w_size + vl_size + vr_size); 2155 + if (!mem_buff) { 2156 + goto error; 2157 + } 2158 + 2159 + a = mem_buff; 2160 + wr = a + a_size; 2161 + wi = wr + wr_size; 2162 + vlr = wi + wi_size; 2163 + vrr = vlr + vlr_size; 2164 + w = vrr + vrr_size; 2165 + vl = w + w_size; 2166 + vr = vl + vl_size; 2167 + 2168 + params->A = (typ*)a; 2169 + params->WR = (typ*)wr; 2170 + params->WI = (typ*)wi; 2171 + params->VLR = (typ*)vlr; 2172 + params->VRR = (typ*)vrr; 2173 + params->W = (typ*)w; 2174 + params->VL = (typ*)vl; 2175 + params->VR = (typ*)vr; 2176 + params->N = n; 2177 + params->LDA = ld; 2178 + params->LDVL = ld; 2179 + params->LDVR = ld; 2180 + params->JOBVL = jobvl; 2181 + params->JOBVR = jobvr; 2182 + 2183 + /* Work size query */ 2184 + { 2185 + typ work_size_query; 2186 + 2187 + params->LWORK = -1; 2188 + params->WORK = &work_size_query; 2189 + 2190 + if (call_geev(params) != 0) { 2191 + goto error; 2192 + } 2193 + 2194 + work_count = (size_t)work_size_query; 2195 + } 2196 + 2197 + mem_buff2 = (npy_uint8 *)malloc(work_count*sizeof(typ)); 2198 + if (!mem_buff2) { 2199 + goto error; 2200 + } 2201 + work = mem_buff2; 2202 + 2203 + params->LWORK = (fortran_int)work_count; 2204 + params->WORK = (typ*)work; 2205 + 2206 + return 1; 2207 + error: 2208 + free(mem_buff2); 2209 + free(mem_buff); 2210 + memset(params, 0, sizeof(*params)); 2211 + 2212 + return 0; 2213 + } 2214 + 2215 + template<typename complextyp, typename typ> 2216 + static inline void 2217 + mk_complex_array_from_real(complextyp *c, const typ *re, size_t n) 2218 + { 2219 + size_t iter; 2220 + for (iter = 0; iter < n; ++iter) { 2221 + c[iter].r = re[iter]; 2222 + c[iter].i = numeric_limits<typ>::zero; 2223 + } 2224 + } 2225 + 2226 + template<typename complextyp, typename typ> 2227 + static inline void 2228 + mk_complex_array(complextyp *c, 2229 + const typ *re, 2230 + const typ *im, 2231 + size_t n) 2232 + { 2233 + size_t iter; 2234 + for (iter = 0; iter < n; ++iter) { 2235 + c[iter].r = re[iter]; 2236 + c[iter].i = im[iter]; 2237 + } 2238 + } 2239 + 2240 + template<typename complextyp, typename typ> 2241 + static inline void 2242 + mk_complex_array_conjugate_pair(complextyp *c, 2243 + const typ *r, 2244 + size_t n) 2245 + { 2246 + size_t iter; 2247 + for (iter = 0; iter < n; ++iter) { 2248 + typ re = r[iter]; 2249 + typ im = r[iter+n]; 2250 + c[iter].r = re; 2251 + c[iter].i = im; 2252 + c[iter+n].r = re; 2253 + c[iter+n].i = -im; 2254 + } 2255 + } 2256 + 2257 + /* 2258 + * make the complex eigenvectors from the real array produced by sgeev/zgeev. 2259 + * c is the array where the results will be left. 2260 + * r is the source array of reals produced by sgeev/zgeev 2261 + * i is the eigenvalue imaginary part produced by sgeev/zgeev 2262 + * n is so that the order of the matrix is n by n 2263 + */ 2264 + template<typename complextyp, typename typ> 2265 + static inline void 2266 + mk_geev_complex_eigenvectors(complextyp *c, 2267 + const typ *r, 2268 + const typ *i, 2269 + size_t n) 2270 + { 2271 + size_t iter = 0; 2272 + while (iter < n) 2273 + { 2274 + if (i[iter] == numeric_limits<typ>::zero) { 2275 + /* eigenvalue was real, eigenvectors as well... */ 2276 + mk_complex_array_from_real(c, r, n); 2277 + c += n; 2278 + r += n; 2279 + iter ++; 2280 + } else { 2281 + /* eigenvalue was complex, generate a pair of eigenvectors */ 2282 + mk_complex_array_conjugate_pair(c, r, n); 2283 + c += 2*n; 2284 + r += 2*n; 2285 + iter += 2; 2286 + } 2287 + } 2288 + } 2289 + 2290 + 2291 + template<typename complextyp, typename typ> 2292 + static inline void 2293 + process_geev_results(GEEV_PARAMS_t<typ> *params, scalar_trait) 2294 + { 2295 + /* REAL versions of geev need the results to be translated 2296 + * into complex versions. This is the way to deal with imaginary 2297 + * results. In our gufuncs we will always return complex arrays! 2298 + */ 2299 + mk_complex_array((complextyp*)params->W, (typ*)params->WR, (typ*)params->WI, params->N); 2300 + 2301 + /* handle the eigenvectors */ 2302 + if ('V' == params->JOBVL) { 2303 + mk_geev_complex_eigenvectors((complextyp*)params->VL, (typ*)params->VLR, 2304 + (typ*)params->WI, params->N); 2305 + } 2306 + if ('V' == params->JOBVR) { 2307 + mk_geev_complex_eigenvectors((complextyp*)params->VR, (typ*)params->VRR, 2308 + (typ*)params->WI, params->N); 2309 + } 2310 + } 2311 + 2312 + 2313 + static inline fortran_int 2314 + call_geev(GEEV_PARAMS_t<fortran_complex>* params) 2315 + { 2316 + fortran_int rv; 2317 + 2318 + LAPACK(cgeev)(&params->JOBVL, &params->JOBVR, 2319 + &params->N, params->A, &params->LDA, 2320 + params->W, 2321 + params->VL, &params->LDVL, 2322 + params->VR, &params->LDVR, 2323 + params->WORK, &params->LWORK, 2324 + params->WR, /* actually RWORK */ 2325 + &rv); 2326 + return rv; 2327 + } 2328 + static inline fortran_int 2329 + call_geev(GEEV_PARAMS_t<fortran_doublecomplex>* params) 2330 + { 2331 + fortran_int rv; 2332 + 2333 + LAPACK(zgeev)(&params->JOBVL, &params->JOBVR, 2334 + &params->N, params->A, &params->LDA, 2335 + params->W, 2336 + params->VL, &params->LDVL, 2337 + params->VR, &params->LDVR, 2338 + params->WORK, &params->LWORK, 2339 + params->WR, /* actually RWORK */ 2340 + &rv); 2341 + return rv; 2342 + } 2343 + 2344 + template<typename ftyp> 2345 + static inline int 2346 + init_geev(GEEV_PARAMS_t<ftyp>* params, 2347 + char jobvl, 2348 + char jobvr, 2349 + fortran_int n, complex_trait) 2350 + { 2351 + using realtyp = basetype_t<ftyp>; 2352 + npy_uint8 *mem_buff = NULL; 2353 + npy_uint8 *mem_buff2 = NULL; 2354 + npy_uint8 *a, *w, *vl, *vr, *work, *rwork; 2355 + size_t safe_n = n; 2356 + size_t a_size = safe_n * safe_n * sizeof(ftyp); 2357 + size_t w_size = safe_n * sizeof(ftyp); 2358 + size_t vl_size = jobvl=='V'? safe_n * safe_n * sizeof(ftyp) : 0; 2359 + size_t vr_size = jobvr=='V'? safe_n * safe_n * sizeof(ftyp) : 0; 2360 + size_t rwork_size = 2 * safe_n * sizeof(realtyp); 2361 + size_t work_count = 0; 2362 + size_t total_size = a_size + w_size + vl_size + vr_size + rwork_size; 2363 + fortran_int ld = fortran_int_max(n, 1); 2364 + 2365 + mem_buff = (npy_uint8 *)malloc(total_size); 2366 + if (!mem_buff) { 2367 + goto error; 2368 + } 2369 + 2370 + a = mem_buff; 2371 + w = a + a_size; 2372 + vl = w + w_size; 2373 + vr = vl + vl_size; 2374 + rwork = vr + vr_size; 2375 + 2376 + params->A = (ftyp*)a; 2377 + params->WR = (realtyp*)rwork; 2378 + params->WI = NULL; 2379 + params->VLR = NULL; 2380 + params->VRR = NULL; 2381 + params->VL = (ftyp*)vl; 2382 + params->VR = (ftyp*)vr; 2383 + params->W = (ftyp*)w; 2384 + params->N = n; 2385 + params->LDA = ld; 2386 + params->LDVL = ld; 2387 + params->LDVR = ld; 2388 + params->JOBVL = jobvl; 2389 + params->JOBVR = jobvr; 2390 + 2391 + /* Work size query */ 2392 + { 2393 + ftyp work_size_query; 2394 + 2395 + params->LWORK = -1; 2396 + params->WORK = &work_size_query; 2397 + 2398 + if (call_geev(params) != 0) { 2399 + goto error; 2400 + } 2401 + 2402 + work_count = (size_t) work_size_query.r; 2403 + /* Fix a bug in lapack 3.0.0 */ 2404 + if(work_count == 0) work_count = 1; 2405 + } 2406 + 2407 + mem_buff2 = (npy_uint8 *)malloc(work_count*sizeof(ftyp)); 2408 + if (!mem_buff2) { 2409 + goto error; 2410 + } 2411 + 2412 + work = mem_buff2; 2413 + 2414 + params->LWORK = (fortran_int)work_count; 2415 + params->WORK = (ftyp*)work; 2416 + 2417 + return 1; 2418 + error: 2419 + free(mem_buff2); 2420 + free(mem_buff); 2421 + memset(params, 0, sizeof(*params)); 2422 + 2423 + return 0; 2424 + } 2425 + 2426 + template<typename complextyp, typename typ> 2427 + static inline void 2428 + process_geev_results(GEEV_PARAMS_t<typ> *NPY_UNUSED(params), complex_trait) 2429 + { 2430 + /* nothing to do here, complex versions are ready to copy out */ 2431 + } 2432 + 2433 + 2434 + 2435 + template<typename typ> 2436 + static inline void 2437 + release_geev(GEEV_PARAMS_t<typ> *params) 2438 + { 2439 + free(params->WORK); 2440 + free(params->A); 2441 + memset(params, 0, sizeof(*params)); 2442 + } 2443 + 2444 + template<typename fctype, typename ftype> 2445 + static inline void 2446 + eig_wrapper(char JOBVL, 2447 + char JOBVR, 2448 + char**args, 2449 + npy_intp const *dimensions, 2450 + npy_intp const *steps) 2451 + { 2452 + ptrdiff_t outer_steps[4]; 2453 + size_t iter; 2454 + size_t outer_dim = *dimensions++; 2455 + size_t op_count = 2; 2456 + int error_occurred = get_fp_invalid_and_clear(); 2457 + GEEV_PARAMS_t<ftype> geev_params; 2458 + 2459 + assert(JOBVL == 'N'); 2460 + 2461 + STACK_TRACE; 2462 + op_count += 'V'==JOBVL?1:0; 2463 + op_count += 'V'==JOBVR?1:0; 2464 + 2465 + for (iter = 0; iter < op_count; ++iter) { 2466 + outer_steps[iter] = (ptrdiff_t) steps[iter]; 2467 + } 2468 + steps += op_count; 2469 + 2470 + if (init_geev(&geev_params, 2471 + JOBVL, JOBVR, 2472 + (fortran_int)dimensions[0], dispatch_scalar<ftype>())) { 2473 + LINEARIZE_DATA_t a_in; 2474 + LINEARIZE_DATA_t w_out; 2475 + LINEARIZE_DATA_t vl_out; 2476 + LINEARIZE_DATA_t vr_out; 2477 + 2478 + init_linearize_data(&a_in, 2479 + geev_params.N, geev_params.N, 2480 + steps[1], steps[0]); 2481 + steps += 2; 2482 + init_linearize_data(&w_out, 2483 + 1, geev_params.N, 2484 + 0, steps[0]); 2485 + steps += 1; 2486 + if ('V' == geev_params.JOBVL) { 2487 + init_linearize_data(&vl_out, 2488 + geev_params.N, geev_params.N, 2489 + steps[1], steps[0]); 2490 + steps += 2; 2491 + } 2492 + if ('V' == geev_params.JOBVR) { 2493 + init_linearize_data(&vr_out, 2494 + geev_params.N, geev_params.N, 2495 + steps[1], steps[0]); 2496 + } 2497 + 2498 + for (iter = 0; iter < outer_dim; ++iter) { 2499 + int not_ok; 2500 + char **arg_iter = args; 2501 + /* copy the matrix in */ 2502 + linearize_matrix((ftype*)geev_params.A, (ftype*)*arg_iter++, &a_in); 2503 + not_ok = call_geev(&geev_params); 2504 + 2505 + if (!not_ok) { 2506 + process_geev_results<fctype>(&geev_params, 2507 + dispatch_scalar<ftype>{}); 2508 + delinearize_matrix((fctype*)*arg_iter++, 2509 + (fctype*)geev_params.W, 2510 + &w_out); 2511 + 2512 + if ('V' == geev_params.JOBVL) { 2513 + delinearize_matrix((fctype*)*arg_iter++, 2514 + (fctype*)geev_params.VL, 2515 + &vl_out); 2516 + } 2517 + if ('V' == geev_params.JOBVR) { 2518 + delinearize_matrix((fctype*)*arg_iter++, 2519 + (fctype*)geev_params.VR, 2520 + &vr_out); 2521 + } 2522 + } else { 2523 + /* geev failed */ 2524 + error_occurred = 1; 2525 + nan_matrix((fctype*)*arg_iter++, &w_out); 2526 + if ('V' == geev_params.JOBVL) { 2527 + nan_matrix((fctype*)*arg_iter++, &vl_out); 2528 + } 2529 + if ('V' == geev_params.JOBVR) { 2530 + nan_matrix((fctype*)*arg_iter++, &vr_out); 2531 + } 2532 + } 2533 + update_pointers((npy_uint8**)args, outer_steps, op_count); 2534 + } 2535 + 2536 + release_geev(&geev_params); 2537 + } 2538 + 2539 + set_fp_invalid_or_clear(error_occurred); 2540 + } 2541 + 2542 + template<typename fctype, typename ftype> 2543 + static void 2544 + eig(char **args, 2545 + npy_intp const *dimensions, 2546 + npy_intp const *steps, 2547 + void *NPY_UNUSED(func)) 2548 + { 2549 + eig_wrapper<fctype, ftype>('N', 'V', args, dimensions, steps); 2550 + } 2551 + 2552 + template<typename fctype, typename ftype> 2553 + static void 2554 + eigvals(char **args, 2555 + npy_intp const *dimensions, 2556 + npy_intp const *steps, 2557 + void *NPY_UNUSED(func)) 2558 + { 2559 + eig_wrapper<fctype, ftype>('N', 'N', args, dimensions, steps); 2560 + } 2561 + 2562 + 2563 + 2564 + /* -------------------------------------------------------------------------- */ 2565 + /* singular value decomposition */ 2566 + 2567 + template<typename ftyp> 2568 + struct GESDD_PARAMS_t 2569 + { 2570 + ftyp *A; 2571 + basetype_t<ftyp> *S; 2572 + ftyp *U; 2573 + ftyp *VT; 2574 + ftyp *WORK; 2575 + basetype_t<ftyp> *RWORK; 2576 + fortran_int *IWORK; 2577 + 2578 + fortran_int M; 2579 + fortran_int N; 2580 + fortran_int LDA; 2581 + fortran_int LDU; 2582 + fortran_int LDVT; 2583 + fortran_int LWORK; 2584 + char JOBZ; 2585 + } ; 2586 + 2587 + 2588 + template<typename ftyp> 2589 + static inline void 2590 + dump_gesdd_params(const char *name, 2591 + GESDD_PARAMS_t<ftyp> *params) 2592 + { 2593 + TRACE_TXT("\n%s:\n"\ 2594 + 2595 + "%14s: %18p\n"\ 2596 + "%14s: %18p\n"\ 2597 + "%14s: %18p\n"\ 2598 + "%14s: %18p\n"\ 2599 + "%14s: %18p\n"\ 2600 + "%14s: %18p\n"\ 2601 + "%14s: %18p\n"\ 2602 + 2603 + "%14s: %18d\n"\ 2604 + "%14s: %18d\n"\ 2605 + "%14s: %18d\n"\ 2606 + "%14s: %18d\n"\ 2607 + "%14s: %18d\n"\ 2608 + "%14s: %18d\n"\ 2609 + 2610 + "%14s: %15c'%c'\n", 2611 + 2612 + name, 2613 + 2614 + "A", params->A, 2615 + "S", params->S, 2616 + "U", params->U, 2617 + "VT", params->VT, 2618 + "WORK", params->WORK, 2619 + "RWORK", params->RWORK, 2620 + "IWORK", params->IWORK, 2621 + 2622 + "M", (int)params->M, 2623 + "N", (int)params->N, 2624 + "LDA", (int)params->LDA, 2625 + "LDU", (int)params->LDU, 2626 + "LDVT", (int)params->LDVT, 2627 + "LWORK", (int)params->LWORK, 2628 + 2629 + "JOBZ", ' ', params->JOBZ); 2630 + } 2631 + 2632 + static inline int 2633 + compute_urows_vtcolumns(char jobz, 2634 + fortran_int m, fortran_int n, 2635 + fortran_int *urows, fortran_int *vtcolumns) 2636 + { 2637 + fortran_int min_m_n = fortran_int_min(m, n); 2638 + switch(jobz) 2639 + { 2640 + case 'N': 2641 + *urows = 0; 2642 + *vtcolumns = 0; 2643 + break; 2644 + case 'A': 2645 + *urows = m; 2646 + *vtcolumns = n; 2647 + break; 2648 + case 'S': 2649 + { 2650 + *urows = min_m_n; 2651 + *vtcolumns = min_m_n; 2652 + } 2653 + break; 2654 + default: 2655 + return 0; 2656 + } 2657 + 2658 + return 1; 2659 + } 2660 + 2661 + static inline fortran_int 2662 + call_gesdd(GESDD_PARAMS_t<fortran_real> *params) 2663 + { 2664 + fortran_int rv; 2665 + LAPACK(sgesdd)(&params->JOBZ, &params->M, &params->N, 2666 + params->A, &params->LDA, 2667 + params->S, 2668 + params->U, &params->LDU, 2669 + params->VT, &params->LDVT, 2670 + params->WORK, &params->LWORK, 2671 + (fortran_int*)params->IWORK, 2672 + &rv); 2673 + return rv; 2674 + } 2675 + static inline fortran_int 2676 + call_gesdd(GESDD_PARAMS_t<fortran_doublereal> *params) 2677 + { 2678 + fortran_int rv; 2679 + LAPACK(dgesdd)(&params->JOBZ, &params->M, &params->N, 2680 + params->A, &params->LDA, 2681 + params->S, 2682 + params->U, &params->LDU, 2683 + params->VT, &params->LDVT, 2684 + params->WORK, &params->LWORK, 2685 + (fortran_int*)params->IWORK, 2686 + &rv); 2687 + return rv; 2688 + } 2689 + 2690 + template<typename ftyp> 2691 + static inline int 2692 + init_gesdd(GESDD_PARAMS_t<ftyp> *params, 2693 + char jobz, 2694 + fortran_int m, 2695 + fortran_int n, scalar_trait) 2696 + { 2697 + npy_uint8 *mem_buff = NULL; 2698 + npy_uint8 *mem_buff2 = NULL; 2699 + npy_uint8 *a, *s, *u, *vt, *work, *iwork; 2700 + size_t safe_m = m; 2701 + size_t safe_n = n; 2702 + size_t a_size = safe_m * safe_n * sizeof(ftyp); 2703 + fortran_int min_m_n = fortran_int_min(m, n); 2704 + size_t safe_min_m_n = min_m_n; 2705 + size_t s_size = safe_min_m_n * sizeof(ftyp); 2706 + fortran_int u_row_count, vt_column_count; 2707 + size_t safe_u_row_count, safe_vt_column_count; 2708 + size_t u_size, vt_size; 2709 + fortran_int work_count; 2710 + size_t work_size; 2711 + size_t iwork_size = 8 * safe_min_m_n * sizeof(fortran_int); 2712 + fortran_int ld = fortran_int_max(m, 1); 2713 + 2714 + if (!compute_urows_vtcolumns(jobz, m, n, &u_row_count, &vt_column_count)) { 2715 + goto error; 2716 + } 2717 + 2718 + safe_u_row_count = u_row_count; 2719 + safe_vt_column_count = vt_column_count; 2720 + 2721 + u_size = safe_u_row_count * safe_m * sizeof(ftyp); 2722 + vt_size = safe_n * safe_vt_column_count * sizeof(ftyp); 2723 + 2724 + mem_buff = (npy_uint8 *)malloc(a_size + s_size + u_size + vt_size + iwork_size); 2725 + 2726 + if (!mem_buff) { 2727 + goto error; 2728 + } 2729 + 2730 + a = mem_buff; 2731 + s = a + a_size; 2732 + u = s + s_size; 2733 + vt = u + u_size; 2734 + iwork = vt + vt_size; 2735 + 2736 + /* fix vt_column_count so that it is a valid lapack parameter (0 is not) */ 2737 + vt_column_count = fortran_int_max(1, vt_column_count); 2738 + 2739 + params->M = m; 2740 + params->N = n; 2741 + params->A = (ftyp*)a; 2742 + params->S = (ftyp*)s; 2743 + params->U = (ftyp*)u; 2744 + params->VT = (ftyp*)vt; 2745 + params->RWORK = NULL; 2746 + params->IWORK = (fortran_int*)iwork; 2747 + params->LDA = ld; 2748 + params->LDU = ld; 2749 + params->LDVT = vt_column_count; 2750 + params->JOBZ = jobz; 2751 + 2752 + /* Work size query */ 2753 + { 2754 + ftyp work_size_query; 2755 + 2756 + params->LWORK = -1; 2757 + params->WORK = &work_size_query; 2758 + 2759 + if (call_gesdd(params) != 0) { 2760 + goto error; 2761 + } 2762 + 2763 + work_count = (fortran_int)work_size_query; 2764 + /* Fix a bug in lapack 3.0.0 */ 2765 + if(work_count == 0) work_count = 1; 2766 + work_size = (size_t)work_count * sizeof(ftyp); 2767 + } 2768 + 2769 + mem_buff2 = (npy_uint8 *)malloc(work_size); 2770 + if (!mem_buff2) { 2771 + goto error; 2772 + } 2773 + 2774 + work = mem_buff2; 2775 + 2776 + params->LWORK = work_count; 2777 + params->WORK = (ftyp*)work; 2778 + 2779 + return 1; 2780 + error: 2781 + TRACE_TXT("%s failed init\n", __FUNCTION__); 2782 + free(mem_buff); 2783 + free(mem_buff2); 2784 + memset(params, 0, sizeof(*params)); 2785 + 2786 + return 0; 2787 + } 2788 + 2789 + static inline fortran_int 2790 + call_gesdd(GESDD_PARAMS_t<fortran_complex> *params) 2791 + { 2792 + fortran_int rv; 2793 + LAPACK(cgesdd)(&params->JOBZ, &params->M, &params->N, 2794 + params->A, &params->LDA, 2795 + params->S, 2796 + params->U, &params->LDU, 2797 + params->VT, &params->LDVT, 2798 + params->WORK, &params->LWORK, 2799 + params->RWORK, 2800 + params->IWORK, 2801 + &rv); 2802 + return rv; 2803 + } 2804 + static inline fortran_int 2805 + call_gesdd(GESDD_PARAMS_t<fortran_doublecomplex> *params) 2806 + { 2807 + fortran_int rv; 2808 + LAPACK(zgesdd)(&params->JOBZ, &params->M, &params->N, 2809 + params->A, &params->LDA, 2810 + params->S, 2811 + params->U, &params->LDU, 2812 + params->VT, &params->LDVT, 2813 + params->WORK, &params->LWORK, 2814 + params->RWORK, 2815 + params->IWORK, 2816 + &rv); 2817 + return rv; 2818 + } 2819 + 2820 + template<typename ftyp> 2821 + static inline int 2822 + init_gesdd(GESDD_PARAMS_t<ftyp> *params, 2823 + char jobz, 2824 + fortran_int m, 2825 + fortran_int n, complex_trait) 2826 + { 2827 + using frealtyp = basetype_t<ftyp>; 2828 + npy_uint8 *mem_buff = NULL, *mem_buff2 = NULL; 2829 + npy_uint8 *a,*s, *u, *vt, *work, *rwork, *iwork; 2830 + size_t a_size, s_size, u_size, vt_size, work_size, rwork_size, iwork_size; 2831 + size_t safe_u_row_count, safe_vt_column_count; 2832 + fortran_int u_row_count, vt_column_count, work_count; 2833 + size_t safe_m = m; 2834 + size_t safe_n = n; 2835 + fortran_int min_m_n = fortran_int_min(m, n); 2836 + size_t safe_min_m_n = min_m_n; 2837 + fortran_int ld = fortran_int_max(m, 1); 2838 + 2839 + if (!compute_urows_vtcolumns(jobz, m, n, &u_row_count, &vt_column_count)) { 2840 + goto error; 2841 + } 2842 + 2843 + safe_u_row_count = u_row_count; 2844 + safe_vt_column_count = vt_column_count; 2845 + 2846 + a_size = safe_m * safe_n * sizeof(ftyp); 2847 + s_size = safe_min_m_n * sizeof(frealtyp); 2848 + u_size = safe_u_row_count * safe_m * sizeof(ftyp); 2849 + vt_size = safe_n * safe_vt_column_count * sizeof(ftyp); 2850 + rwork_size = 'N'==jobz? 2851 + (7 * safe_min_m_n) : 2852 + (5*safe_min_m_n * safe_min_m_n + 5*safe_min_m_n); 2853 + rwork_size *= sizeof(ftyp); 2854 + iwork_size = 8 * safe_min_m_n* sizeof(fortran_int); 2855 + 2856 + mem_buff = (npy_uint8 *)malloc(a_size + 2857 + s_size + 2858 + u_size + 2859 + vt_size + 2860 + rwork_size + 2861 + iwork_size); 2862 + if (!mem_buff) { 2863 + goto error; 2864 + } 2865 + 2866 + a = mem_buff; 2867 + s = a + a_size; 2868 + u = s + s_size; 2869 + vt = u + u_size; 2870 + rwork = vt + vt_size; 2871 + iwork = rwork + rwork_size; 2872 + 2873 + /* fix vt_column_count so that it is a valid lapack parameter (0 is not) */ 2874 + vt_column_count = fortran_int_max(1, vt_column_count); 2875 + 2876 + params->A = (ftyp*)a; 2877 + params->S = (frealtyp*)s; 2878 + params->U = (ftyp*)u; 2879 + params->VT = (ftyp*)vt; 2880 + params->RWORK = (frealtyp*)rwork; 2881 + params->IWORK = (fortran_int*)iwork; 2882 + params->M = m; 2883 + params->N = n; 2884 + params->LDA = ld; 2885 + params->LDU = ld; 2886 + params->LDVT = vt_column_count; 2887 + params->JOBZ = jobz; 2888 + 2889 + /* Work size query */ 2890 + { 2891 + ftyp work_size_query; 2892 + 2893 + params->LWORK = -1; 2894 + params->WORK = &work_size_query; 2895 + 2896 + if (call_gesdd(params) != 0) { 2897 + goto error; 2898 + } 2899 + 2900 + work_count = (fortran_int)(*(frealtyp*)&work_size_query); 2901 + /* Fix a bug in lapack 3.0.0 */ 2902 + if(work_count == 0) work_count = 1; 2903 + work_size = (size_t)work_count * sizeof(ftyp); 2904 + } 2905 + 2906 + mem_buff2 = (npy_uint8 *)malloc(work_size); 2907 + if (!mem_buff2) { 2908 + goto error; 2909 + } 2910 + 2911 + work = mem_buff2; 2912 + 2913 + params->LWORK = work_count; 2914 + params->WORK = (ftyp*)work; 2915 + 2916 + return 1; 2917 + error: 2918 + TRACE_TXT("%s failed init\n", __FUNCTION__); 2919 + free(mem_buff2); 2920 + free(mem_buff); 2921 + memset(params, 0, sizeof(*params)); 2922 + 2923 + return 0; 2924 + } 2925 + 2926 + template<typename typ> 2927 + static inline void 2928 + release_gesdd(GESDD_PARAMS_t<typ>* params) 2929 + { 2930 + /* A and WORK contain allocated blocks */ 2931 + free(params->A); 2932 + free(params->WORK); 2933 + memset(params, 0, sizeof(*params)); 2934 + } 2935 + 2936 + template<typename typ> 2937 + static inline void 2938 + svd_wrapper(char JOBZ, 2939 + char **args, 2940 + npy_intp const *dimensions, 2941 + npy_intp const *steps) 2942 + { 2943 + using basetyp = basetype_t<typ>; 2944 + ptrdiff_t outer_steps[4]; 2945 + int error_occurred = get_fp_invalid_and_clear(); 2946 + size_t iter; 2947 + size_t outer_dim = *dimensions++; 2948 + size_t op_count = (JOBZ=='N')?2:4; 2949 + GESDD_PARAMS_t<typ> params; 2950 + 2951 + for (iter = 0; iter < op_count; ++iter) { 2952 + outer_steps[iter] = (ptrdiff_t) steps[iter]; 2953 + } 2954 + steps += op_count; 2955 + 2956 + if (init_gesdd(&params, 2957 + JOBZ, 2958 + (fortran_int)dimensions[0], 2959 + (fortran_int)dimensions[1], 2960 + dispatch_scalar<typ>())) { 2961 + LINEARIZE_DATA_t a_in, u_out, s_out, v_out; 2962 + fortran_int min_m_n = params.M < params.N ? params.M : params.N; 2963 + 2964 + init_linearize_data(&a_in, params.N, params.M, steps[1], steps[0]); 2965 + if ('N' == params.JOBZ) { 2966 + /* only the singular values are wanted */ 2967 + init_linearize_data(&s_out, 1, min_m_n, 0, steps[2]); 2968 + } else { 2969 + fortran_int u_columns, v_rows; 2970 + if ('S' == params.JOBZ) { 2971 + u_columns = min_m_n; 2972 + v_rows = min_m_n; 2973 + } else { /* JOBZ == 'A' */ 2974 + u_columns = params.M; 2975 + v_rows = params.N; 2976 + } 2977 + init_linearize_data(&u_out, 2978 + u_columns, params.M, 2979 + steps[3], steps[2]); 2980 + init_linearize_data(&s_out, 2981 + 1, min_m_n, 2982 + 0, steps[4]); 2983 + init_linearize_data(&v_out, 2984 + params.N, v_rows, 2985 + steps[6], steps[5]); 2986 + } 2987 + 2988 + for (iter = 0; iter < outer_dim; ++iter) { 2989 + int not_ok; 2990 + /* copy the matrix in */ 2991 + linearize_matrix((typ*)params.A, (typ*)args[0], &a_in); 2992 + not_ok = call_gesdd(&params); 2993 + if (!not_ok) { 2994 + if ('N' == params.JOBZ) { 2995 + delinearize_matrix((basetyp*)args[1], (basetyp*)params.S, &s_out); 2996 + } else { 2997 + if ('A' == params.JOBZ && min_m_n == 0) { 2998 + /* Lapack has betrayed us and left these uninitialized, 2999 + * so produce an identity matrix for whichever of u 3000 + * and v is not empty. 3001 + */ 3002 + identity_matrix((typ*)params.U, params.M); 3003 + identity_matrix((typ*)params.VT, params.N); 3004 + } 3005 + 3006 + delinearize_matrix((typ*)args[1], (typ*)params.U, &u_out); 3007 + delinearize_matrix((basetyp*)args[2], (basetyp*)params.S, &s_out); 3008 + delinearize_matrix((typ*)args[3], (typ*)params.VT, &v_out); 3009 + } 3010 + } else { 3011 + error_occurred = 1; 3012 + if ('N' == params.JOBZ) { 3013 + nan_matrix((basetyp*)args[1], &s_out); 3014 + } else { 3015 + nan_matrix((typ*)args[1], &u_out); 3016 + nan_matrix((basetyp*)args[2], &s_out); 3017 + nan_matrix((typ*)args[3], &v_out); 3018 + } 3019 + } 3020 + update_pointers((npy_uint8**)args, outer_steps, op_count); 3021 + } 3022 + 3023 + release_gesdd(&params); 3024 + } 3025 + 3026 + set_fp_invalid_or_clear(error_occurred); 3027 + } 3028 + 3029 + 3030 + template<typename typ> 3031 + static void 3032 + svd_N(char **args, 3033 + npy_intp const *dimensions, 3034 + npy_intp const *steps, 3035 + void *NPY_UNUSED(func)) 3036 + { 3037 + svd_wrapper<fortran_type_t<typ>>('N', args, dimensions, steps); 3038 + } 3039 + 3040 + template<typename typ> 3041 + static void 3042 + svd_S(char **args, 3043 + npy_intp const *dimensions, 3044 + npy_intp const *steps, 3045 + void *NPY_UNUSED(func)) 3046 + { 3047 + svd_wrapper<fortran_type_t<typ>>('S', args, dimensions, steps); 3048 + } 3049 + 3050 + template<typename typ> 3051 + static void 3052 + svd_A(char **args, 3053 + npy_intp const *dimensions, 3054 + npy_intp const *steps, 3055 + void *NPY_UNUSED(func)) 3056 + { 3057 + svd_wrapper<fortran_type_t<typ>>('A', args, dimensions, steps); 3058 + } 3059 + 3060 + /* -------------------------------------------------------------------------- */ 3061 + /* qr (modes - r, raw) */ 3062 + 3063 + template<typename typ> 3064 + struct GEQRF_PARAMS_t 3065 + { 3066 + fortran_int M; 3067 + fortran_int N; 3068 + typ *A; 3069 + fortran_int LDA; 3070 + typ* TAU; 3071 + typ *WORK; 3072 + fortran_int LWORK; 3073 + }; 3074 + 3075 + 3076 + template<typename typ> 3077 + static inline void 3078 + dump_geqrf_params(const char *name, 3079 + GEQRF_PARAMS_t<typ> *params) 3080 + { 3081 + TRACE_TXT("\n%s:\n"\ 3082 + 3083 + "%14s: %18p\n"\ 3084 + "%14s: %18p\n"\ 3085 + "%14s: %18p\n"\ 3086 + "%14s: %18d\n"\ 3087 + "%14s: %18d\n"\ 3088 + "%14s: %18d\n"\ 3089 + "%14s: %18d\n", 3090 + 3091 + name, 3092 + 3093 + "A", params->A, 3094 + "TAU", params->TAU, 3095 + "WORK", params->WORK, 3096 + 3097 + "M", (int)params->M, 3098 + "N", (int)params->N, 3099 + "LDA", (int)params->LDA, 3100 + "LWORK", (int)params->LWORK); 3101 + } 3102 + 3103 + static inline fortran_int 3104 + call_geqrf(GEQRF_PARAMS_t<double> *params) 3105 + { 3106 + fortran_int rv; 3107 + LAPACK(dgeqrf)(&params->M, &params->N, 3108 + params->A, &params->LDA, 3109 + params->TAU, 3110 + params->WORK, &params->LWORK, 3111 + &rv); 3112 + return rv; 3113 + } 3114 + static inline fortran_int 3115 + call_geqrf(GEQRF_PARAMS_t<f2c_doublecomplex> *params) 3116 + { 3117 + fortran_int rv; 3118 + LAPACK(zgeqrf)(&params->M, &params->N, 3119 + params->A, &params->LDA, 3120 + params->TAU, 3121 + params->WORK, &params->LWORK, 3122 + &rv); 3123 + return rv; 3124 + } 3125 + 3126 + 3127 + static inline int 3128 + init_geqrf(GEQRF_PARAMS_t<fortran_doublereal> *params, 3129 + fortran_int m, 3130 + fortran_int n) 3131 + { 3132 + using ftyp = fortran_doublereal; 3133 + npy_uint8 *mem_buff = NULL; 3134 + npy_uint8 *mem_buff2 = NULL; 3135 + npy_uint8 *a, *tau, *work; 3136 + fortran_int min_m_n = fortran_int_min(m, n); 3137 + size_t safe_min_m_n = min_m_n; 3138 + size_t safe_m = m; 3139 + size_t safe_n = n; 3140 + 3141 + size_t a_size = safe_m * safe_n * sizeof(ftyp); 3142 + size_t tau_size = safe_min_m_n * sizeof(ftyp); 3143 + 3144 + fortran_int work_count; 3145 + size_t work_size; 3146 + fortran_int lda = fortran_int_max(1, m); 3147 + 3148 + mem_buff = (npy_uint8 *)malloc(a_size + tau_size); 3149 + 3150 + if (!mem_buff) 3151 + goto error; 3152 + 3153 + a = mem_buff; 3154 + tau = a + a_size; 3155 + memset(tau, 0, tau_size); 3156 + 3157 + 3158 + params->M = m; 3159 + params->N = n; 3160 + params->A = (ftyp*)a; 3161 + params->TAU = (ftyp*)tau; 3162 + params->LDA = lda; 3163 + 3164 + { 3165 + /* compute optimal work size */ 3166 + 3167 + ftyp work_size_query; 3168 + 3169 + params->WORK = &work_size_query; 3170 + params->LWORK = -1; 3171 + 3172 + if (call_geqrf(params) != 0) 3173 + goto error; 3174 + 3175 + work_count = (fortran_int) *(ftyp*) params->WORK; 3176 + 3177 + } 3178 + 3179 + params->LWORK = fortran_int_max(fortran_int_max(1, n), work_count); 3180 + 3181 + work_size = (size_t) params->LWORK * sizeof(ftyp); 3182 + mem_buff2 = (npy_uint8 *)malloc(work_size); 3183 + if (!mem_buff2) 3184 + goto error; 3185 + 3186 + work = mem_buff2; 3187 + 3188 + params->WORK = (ftyp*)work; 3189 + 3190 + return 1; 3191 + error: 3192 + TRACE_TXT("%s failed init\n", __FUNCTION__); 3193 + free(mem_buff); 3194 + free(mem_buff2); 3195 + memset(params, 0, sizeof(*params)); 3196 + 3197 + return 0; 3198 + } 3199 + 3200 + 3201 + static inline int 3202 + init_geqrf(GEQRF_PARAMS_t<fortran_doublecomplex> *params, 3203 + fortran_int m, 3204 + fortran_int n) 3205 + { 3206 + using ftyp = fortran_doublecomplex; 3207 + npy_uint8 *mem_buff = NULL; 3208 + npy_uint8 *mem_buff2 = NULL; 3209 + npy_uint8 *a, *tau, *work; 3210 + fortran_int min_m_n = fortran_int_min(m, n); 3211 + size_t safe_min_m_n = min_m_n; 3212 + size_t safe_m = m; 3213 + size_t safe_n = n; 3214 + 3215 + size_t a_size = safe_m * safe_n * sizeof(ftyp); 3216 + size_t tau_size = safe_min_m_n * sizeof(ftyp); 3217 + 3218 + fortran_int work_count; 3219 + size_t work_size; 3220 + fortran_int lda = fortran_int_max(1, m); 3221 + 3222 + mem_buff = (npy_uint8 *)malloc(a_size + tau_size); 3223 + 3224 + if (!mem_buff) 3225 + goto error; 3226 + 3227 + a = mem_buff; 3228 + tau = a + a_size; 3229 + memset(tau, 0, tau_size); 3230 + 3231 + 3232 + params->M = m; 3233 + params->N = n; 3234 + params->A = (ftyp*)a; 3235 + params->TAU = (ftyp*)tau; 3236 + params->LDA = lda; 3237 + 3238 + { 3239 + /* compute optimal work size */ 3240 + 3241 + ftyp work_size_query; 3242 + 3243 + params->WORK = &work_size_query; 3244 + params->LWORK = -1; 3245 + 3246 + if (call_geqrf(params) != 0) 3247 + goto error; 3248 + 3249 + work_count = (fortran_int) ((ftyp*)params->WORK)->r; 3250 + 3251 + } 3252 + 3253 + params->LWORK = fortran_int_max(fortran_int_max(1, n), 3254 + work_count); 3255 + 3256 + work_size = (size_t) params->LWORK * sizeof(ftyp); 3257 + 3258 + mem_buff2 = (npy_uint8 *)malloc(work_size); 3259 + if (!mem_buff2) 3260 + goto error; 3261 + 3262 + work = mem_buff2; 3263 + 3264 + params->WORK = (ftyp*)work; 3265 + 3266 + return 1; 3267 + error: 3268 + TRACE_TXT("%s failed init\n", __FUNCTION__); 3269 + free(mem_buff); 3270 + free(mem_buff2); 3271 + memset(params, 0, sizeof(*params)); 3272 + 3273 + return 0; 3274 + } 3275 + 3276 + 3277 + template<typename ftyp> 3278 + static inline void 3279 + release_geqrf(GEQRF_PARAMS_t<ftyp>* params) 3280 + { 3281 + /* A and WORK contain allocated blocks */ 3282 + free(params->A); 3283 + free(params->WORK); 3284 + memset(params, 0, sizeof(*params)); 3285 + } 3286 + 3287 + template<typename typ> 3288 + static void 3289 + qr_r_raw(char **args, npy_intp const *dimensions, npy_intp const *steps, 3290 + void *NPY_UNUSED(func)) 3291 + { 3292 + using ftyp = fortran_type_t<typ>; 3293 + 3294 + GEQRF_PARAMS_t<ftyp> params; 3295 + int error_occurred = get_fp_invalid_and_clear(); 3296 + fortran_int n, m; 3297 + 3298 + INIT_OUTER_LOOP_2 3299 + 3300 + m = (fortran_int)dimensions[0]; 3301 + n = (fortran_int)dimensions[1]; 3302 + 3303 + if (init_geqrf(&params, m, n)) { 3304 + LINEARIZE_DATA_t a_in, tau_out; 3305 + 3306 + init_linearize_data(&a_in, n, m, steps[1], steps[0]); 3307 + init_linearize_data(&tau_out, 1, fortran_int_min(m, n), 1, steps[2]); 3308 + 3309 + BEGIN_OUTER_LOOP_2 3310 + int not_ok; 3311 + linearize_matrix((typ*)params.A, (typ*)args[0], &a_in); 3312 + not_ok = call_geqrf(&params); 3313 + if (!not_ok) { 3314 + delinearize_matrix((typ*)args[0], (typ*)params.A, &a_in); 3315 + delinearize_matrix((typ*)args[1], (typ*)params.TAU, &tau_out); 3316 + } else { 3317 + error_occurred = 1; 3318 + nan_matrix((typ*)args[1], &tau_out); 3319 + } 3320 + END_OUTER_LOOP 3321 + 3322 + release_geqrf(&params); 3323 + } 3324 + 3325 + set_fp_invalid_or_clear(error_occurred); 3326 + } 3327 + 3328 + 3329 + /* -------------------------------------------------------------------------- */ 3330 + /* qr common code (modes - reduced and complete) */ 3331 + 3332 + template<typename typ> 3333 + struct GQR_PARAMS_t 3334 + { 3335 + fortran_int M; 3336 + fortran_int MC; 3337 + fortran_int MN; 3338 + void* A; 3339 + typ *Q; 3340 + fortran_int LDA; 3341 + typ* TAU; 3342 + typ *WORK; 3343 + fortran_int LWORK; 3344 + } ; 3345 + 3346 + static inline fortran_int 3347 + call_gqr(GQR_PARAMS_t<double> *params) 3348 + { 3349 + fortran_int rv; 3350 + LAPACK(dorgqr)(&params->M, &params->MC, &params->MN, 3351 + params->Q, &params->LDA, 3352 + params->TAU, 3353 + params->WORK, &params->LWORK, 3354 + &rv); 3355 + return rv; 3356 + } 3357 + static inline fortran_int 3358 + call_gqr(GQR_PARAMS_t<f2c_doublecomplex> *params) 3359 + { 3360 + fortran_int rv; 3361 + LAPACK(zungqr)(&params->M, &params->MC, &params->MN, 3362 + params->Q, &params->LDA, 3363 + params->TAU, 3364 + params->WORK, &params->LWORK, 3365 + &rv); 3366 + return rv; 3367 + } 3368 + 3369 + static inline int 3370 + init_gqr_common(GQR_PARAMS_t<fortran_doublereal> *params, 3371 + fortran_int m, 3372 + fortran_int n, 3373 + fortran_int mc) 3374 + { 3375 + using ftyp = fortran_doublereal; 3376 + npy_uint8 *mem_buff = NULL; 3377 + npy_uint8 *mem_buff2 = NULL; 3378 + npy_uint8 *a, *q, *tau, *work; 3379 + fortran_int min_m_n = fortran_int_min(m, n); 3380 + size_t safe_mc = mc; 3381 + size_t safe_min_m_n = min_m_n; 3382 + size_t safe_m = m; 3383 + size_t safe_n = n; 3384 + size_t a_size = safe_m * safe_n * sizeof(ftyp); 3385 + size_t q_size = safe_m * safe_mc * sizeof(ftyp); 3386 + size_t tau_size = safe_min_m_n * sizeof(ftyp); 3387 + 3388 + fortran_int work_count; 3389 + size_t work_size; 3390 + fortran_int lda = fortran_int_max(1, m); 3391 + 3392 + mem_buff = (npy_uint8 *)malloc(q_size + tau_size + a_size); 3393 + 3394 + if (!mem_buff) 3395 + goto error; 3396 + 3397 + q = mem_buff; 3398 + tau = q + q_size; 3399 + a = tau + tau_size; 3400 + 3401 + 3402 + params->M = m; 3403 + params->MC = mc; 3404 + params->MN = min_m_n; 3405 + params->A = a; 3406 + params->Q = (ftyp*)q; 3407 + params->TAU = (ftyp*)tau; 3408 + params->LDA = lda; 3409 + 3410 + { 3411 + /* compute optimal work size */ 3412 + ftyp work_size_query; 3413 + 3414 + params->WORK = &work_size_query; 3415 + params->LWORK = -1; 3416 + 3417 + if (call_gqr(params) != 0) 3418 + goto error; 3419 + 3420 + work_count = (fortran_int) *(ftyp*) params->WORK; 3421 + 3422 + } 3423 + 3424 + params->LWORK = fortran_int_max(fortran_int_max(1, n), work_count); 3425 + 3426 + work_size = (size_t) params->LWORK * sizeof(ftyp); 3427 + 3428 + mem_buff2 = (npy_uint8 *)malloc(work_size); 3429 + if (!mem_buff2) 3430 + goto error; 3431 + 3432 + work = mem_buff2; 3433 + 3434 + params->WORK = (ftyp*)work; 3435 + 3436 + return 1; 3437 + error: 3438 + TRACE_TXT("%s failed init\n", __FUNCTION__); 3439 + free(mem_buff); 3440 + free(mem_buff2); 3441 + memset(params, 0, sizeof(*params)); 3442 + 3443 + return 0; 3444 + } 3445 + 3446 + 3447 + static inline int 3448 + init_gqr_common(GQR_PARAMS_t<fortran_doublecomplex> *params, 3449 + fortran_int m, 3450 + fortran_int n, 3451 + fortran_int mc) 3452 + { 3453 + using ftyp=fortran_doublecomplex; 3454 + npy_uint8 *mem_buff = NULL; 3455 + npy_uint8 *mem_buff2 = NULL; 3456 + npy_uint8 *a, *q, *tau, *work; 3457 + fortran_int min_m_n = fortran_int_min(m, n); 3458 + size_t safe_mc = mc; 3459 + size_t safe_min_m_n = min_m_n; 3460 + size_t safe_m = m; 3461 + size_t safe_n = n; 3462 + 3463 + size_t a_size = safe_m * safe_n * sizeof(ftyp); 3464 + size_t q_size = safe_m * safe_mc * sizeof(ftyp); 3465 + size_t tau_size = safe_min_m_n * sizeof(ftyp); 3466 + 3467 + fortran_int work_count; 3468 + size_t work_size; 3469 + fortran_int lda = fortran_int_max(1, m); 3470 + 3471 + mem_buff = (npy_uint8 *)malloc(q_size + tau_size + a_size); 3472 + 3473 + if (!mem_buff) 3474 + goto error; 3475 + 3476 + q = mem_buff; 3477 + tau = q + q_size; 3478 + a = tau + tau_size; 3479 + 3480 + 3481 + params->M = m; 3482 + params->MC = mc; 3483 + params->MN = min_m_n; 3484 + params->A = a; 3485 + params->Q = (ftyp*)q; 3486 + params->TAU = (ftyp*)tau; 3487 + params->LDA = lda; 3488 + 3489 + { 3490 + /* compute optimal work size */ 3491 + ftyp work_size_query; 3492 + 3493 + params->WORK = &work_size_query; 3494 + params->LWORK = -1; 3495 + 3496 + if (call_gqr(params) != 0) 3497 + goto error; 3498 + 3499 + work_count = (fortran_int) ((ftyp*)params->WORK)->r; 3500 + 3501 + } 3502 + 3503 + params->LWORK = fortran_int_max(fortran_int_max(1, n), 3504 + work_count); 3505 + 3506 + work_size = (size_t) params->LWORK * sizeof(ftyp); 3507 + 3508 + mem_buff2 = (npy_uint8 *)malloc(work_size); 3509 + if (!mem_buff2) 3510 + goto error; 3511 + 3512 + work = mem_buff2; 3513 + 3514 + params->WORK = (ftyp*)work; 3515 + params->LWORK = work_count; 3516 + 3517 + return 1; 3518 + error: 3519 + TRACE_TXT("%s failed init\n", __FUNCTION__); 3520 + free(mem_buff); 3521 + free(mem_buff2); 3522 + memset(params, 0, sizeof(*params)); 3523 + 3524 + return 0; 3525 + } 3526 + 3527 + /* -------------------------------------------------------------------------- */ 3528 + /* qr (modes - reduced) */ 3529 + 3530 + 3531 + template<typename typ> 3532 + static inline void 3533 + dump_gqr_params(const char *name, 3534 + GQR_PARAMS_t<typ> *params) 3535 + { 3536 + TRACE_TXT("\n%s:\n"\ 3537 + 3538 + "%14s: %18p\n"\ 3539 + "%14s: %18p\n"\ 3540 + "%14s: %18p\n"\ 3541 + "%14s: %18d\n"\ 3542 + "%14s: %18d\n"\ 3543 + "%14s: %18d\n"\ 3544 + "%14s: %18d\n"\ 3545 + "%14s: %18d\n", 3546 + 3547 + name, 3548 + 3549 + "Q", params->Q, 3550 + "TAU", params->TAU, 3551 + "WORK", params->WORK, 3552 + 3553 + "M", (int)params->M, 3554 + "MC", (int)params->MC, 3555 + "MN", (int)params->MN, 3556 + "LDA", (int)params->LDA, 3557 + "LWORK", (int)params->LWORK); 3558 + } 3559 + 3560 + template<typename ftyp> 3561 + static inline int 3562 + init_gqr(GQR_PARAMS_t<ftyp> *params, 3563 + fortran_int m, 3564 + fortran_int n) 3565 + { 3566 + return init_gqr_common( 3567 + params, m, n, 3568 + fortran_int_min(m, n)); 3569 + } 3570 + 3571 + 3572 + template<typename typ> 3573 + static inline void 3574 + release_gqr(GQR_PARAMS_t<typ>* params) 3575 + { 3576 + /* A and WORK contain allocated blocks */ 3577 + free(params->Q); 3578 + free(params->WORK); 3579 + memset(params, 0, sizeof(*params)); 3580 + } 3581 + 3582 + template<typename typ> 3583 + static void 3584 + qr_reduced(char **args, npy_intp const *dimensions, npy_intp const *steps, 3585 + void *NPY_UNUSED(func)) 3586 + { 3587 + using ftyp = fortran_type_t<typ>; 3588 + GQR_PARAMS_t<ftyp> params; 3589 + int error_occurred = get_fp_invalid_and_clear(); 3590 + fortran_int n, m; 3591 + 3592 + INIT_OUTER_LOOP_3 3593 + 3594 + m = (fortran_int)dimensions[0]; 3595 + n = (fortran_int)dimensions[1]; 3596 + 3597 + if (init_gqr(&params, m, n)) { 3598 + LINEARIZE_DATA_t a_in, tau_in, q_out; 3599 + 3600 + init_linearize_data(&a_in, n, m, steps[1], steps[0]); 3601 + init_linearize_data(&tau_in, 1, fortran_int_min(m, n), 1, steps[2]); 3602 + init_linearize_data(&q_out, fortran_int_min(m, n), m, steps[4], steps[3]); 3603 + 3604 + BEGIN_OUTER_LOOP_3 3605 + int not_ok; 3606 + linearize_matrix((typ*)params.A, (typ*)args[0], &a_in); 3607 + linearize_matrix((typ*)params.Q, (typ*)args[0], &a_in); 3608 + linearize_matrix((typ*)params.TAU, (typ*)args[1], &tau_in); 3609 + not_ok = call_gqr(&params); 3610 + if (!not_ok) { 3611 + delinearize_matrix((typ*)args[2], (typ*)params.Q, &q_out); 3612 + } else { 3613 + error_occurred = 1; 3614 + nan_matrix((typ*)args[2], &q_out); 3615 + } 3616 + END_OUTER_LOOP 3617 + 3618 + release_gqr(&params); 3619 + } 3620 + 3621 + set_fp_invalid_or_clear(error_occurred); 3622 + } 3623 + 3624 + /* -------------------------------------------------------------------------- */ 3625 + /* qr (modes - complete) */ 3626 + 3627 + template<typename ftyp> 3628 + static inline int 3629 + init_gqr_complete(GQR_PARAMS_t<ftyp> *params, 3630 + fortran_int m, 3631 + fortran_int n) 3632 + { 3633 + return init_gqr_common(params, m, n, m); 3634 + } 3635 + 3636 + 3637 + template<typename typ> 3638 + static void 3639 + qr_complete(char **args, npy_intp const *dimensions, npy_intp const *steps, 3640 + void *NPY_UNUSED(func)) 3641 + { 3642 + using ftyp = fortran_type_t<typ>; 3643 + GQR_PARAMS_t<ftyp> params; 3644 + int error_occurred = get_fp_invalid_and_clear(); 3645 + fortran_int n, m; 3646 + 3647 + INIT_OUTER_LOOP_3 3648 + 3649 + m = (fortran_int)dimensions[0]; 3650 + n = (fortran_int)dimensions[1]; 3651 + 3652 + 3653 + if (init_gqr_complete(&params, m, n)) { 3654 + LINEARIZE_DATA_t a_in, tau_in, q_out; 3655 + 3656 + init_linearize_data(&a_in, n, m, steps[1], steps[0]); 3657 + init_linearize_data(&tau_in, 1, fortran_int_min(m, n), 1, steps[2]); 3658 + init_linearize_data(&q_out, m, m, steps[4], steps[3]); 3659 + 3660 + BEGIN_OUTER_LOOP_3 3661 + int not_ok; 3662 + linearize_matrix((typ*)params.A, (typ*)args[0], &a_in); 3663 + linearize_matrix((typ*)params.Q, (typ*)args[0], &a_in); 3664 + linearize_matrix((typ*)params.TAU, (typ*)args[1], &tau_in); 3665 + not_ok = call_gqr(&params); 3666 + if (!not_ok) { 3667 + delinearize_matrix((typ*)args[2], (typ*)params.Q, &q_out); 3668 + } else { 3669 + error_occurred = 1; 3670 + nan_matrix((typ*)args[2], &q_out); 3671 + } 3672 + END_OUTER_LOOP 3673 + 3674 + release_gqr(&params); 3675 + } 3676 + 3677 + set_fp_invalid_or_clear(error_occurred); 3678 + } 3679 + 3680 + /* -------------------------------------------------------------------------- */ 3681 + /* least squares */ 3682 + 3683 + template<typename typ> 3684 + struct GELSD_PARAMS_t 3685 + { 3686 + fortran_int M; 3687 + fortran_int N; 3688 + fortran_int NRHS; 3689 + typ *A; 3690 + fortran_int LDA; 3691 + typ *B; 3692 + fortran_int LDB; 3693 + basetype_t<typ> *S; 3694 + basetype_t<typ> *RCOND; 3695 + fortran_int RANK; 3696 + typ *WORK; 3697 + fortran_int LWORK; 3698 + basetype_t<typ> *RWORK; 3699 + fortran_int *IWORK; 3700 + }; 3701 + 3702 + template<typename typ> 3703 + static inline void 3704 + dump_gelsd_params(const char *name, 3705 + GELSD_PARAMS_t<typ> *params) 3706 + { 3707 + TRACE_TXT("\n%s:\n"\ 3708 + 3709 + "%14s: %18p\n"\ 3710 + "%14s: %18p\n"\ 3711 + "%14s: %18p\n"\ 3712 + "%14s: %18p\n"\ 3713 + "%14s: %18p\n"\ 3714 + "%14s: %18p\n"\ 3715 + 3716 + "%14s: %18d\n"\ 3717 + "%14s: %18d\n"\ 3718 + "%14s: %18d\n"\ 3719 + "%14s: %18d\n"\ 3720 + "%14s: %18d\n"\ 3721 + "%14s: %18d\n"\ 3722 + "%14s: %18d\n"\ 3723 + 3724 + "%14s: %18p\n", 3725 + 3726 + name, 3727 + 3728 + "A", params->A, 3729 + "B", params->B, 3730 + "S", params->S, 3731 + "WORK", params->WORK, 3732 + "RWORK", params->RWORK, 3733 + "IWORK", params->IWORK, 3734 + 3735 + "M", (int)params->M, 3736 + "N", (int)params->N, 3737 + "NRHS", (int)params->NRHS, 3738 + "LDA", (int)params->LDA, 3739 + "LDB", (int)params->LDB, 3740 + "LWORK", (int)params->LWORK, 3741 + "RANK", (int)params->RANK, 3742 + 3743 + "RCOND", params->RCOND); 3744 + } 3745 + 3746 + static inline fortran_int 3747 + call_gelsd(GELSD_PARAMS_t<fortran_real> *params) 3748 + { 3749 + fortran_int rv; 3750 + LAPACK(sgelsd)(&params->M, &params->N, &params->NRHS, 3751 + params->A, &params->LDA, 3752 + params->B, &params->LDB, 3753 + params->S, 3754 + params->RCOND, &params->RANK, 3755 + params->WORK, &params->LWORK, 3756 + params->IWORK, 3757 + &rv); 3758 + return rv; 3759 + } 3760 + 3761 + 3762 + static inline fortran_int 3763 + call_gelsd(GELSD_PARAMS_t<fortran_doublereal> *params) 3764 + { 3765 + fortran_int rv; 3766 + LAPACK(dgelsd)(&params->M, &params->N, &params->NRHS, 3767 + params->A, &params->LDA, 3768 + params->B, &params->LDB, 3769 + params->S, 3770 + params->RCOND, &params->RANK, 3771 + params->WORK, &params->LWORK, 3772 + params->IWORK, 3773 + &rv); 3774 + return rv; 3775 + } 3776 + 3777 + 3778 + template<typename ftyp> 3779 + static inline int 3780 + init_gelsd(GELSD_PARAMS_t<ftyp> *params, 3781 + fortran_int m, 3782 + fortran_int n, 3783 + fortran_int nrhs, 3784 + scalar_trait) 3785 + { 3786 + npy_uint8 *mem_buff = NULL; 3787 + npy_uint8 *mem_buff2 = NULL; 3788 + npy_uint8 *a, *b, *s, *work, *iwork; 3789 + fortran_int min_m_n = fortran_int_min(m, n); 3790 + fortran_int max_m_n = fortran_int_max(m, n); 3791 + size_t safe_min_m_n = min_m_n; 3792 + size_t safe_max_m_n = max_m_n; 3793 + size_t safe_m = m; 3794 + size_t safe_n = n; 3795 + size_t safe_nrhs = nrhs; 3796 + 3797 + size_t a_size = safe_m * safe_n * sizeof(ftyp); 3798 + size_t b_size = safe_max_m_n * safe_nrhs * sizeof(ftyp); 3799 + size_t s_size = safe_min_m_n * sizeof(ftyp); 3800 + 3801 + fortran_int work_count; 3802 + size_t work_size; 3803 + size_t iwork_size; 3804 + fortran_int lda = fortran_int_max(1, m); 3805 + fortran_int ldb = fortran_int_max(1, fortran_int_max(m,n)); 3806 + 3807 + size_t msize = a_size + b_size + s_size; 3808 + mem_buff = (npy_uint8 *)malloc(msize != 0 ? msize : 1); 3809 + 3810 + if (!mem_buff) { 3811 + goto no_memory; 3812 + } 3813 + a = mem_buff; 3814 + b = a + a_size; 3815 + s = b + b_size; 3816 + 3817 + params->M = m; 3818 + params->N = n; 3819 + params->NRHS = nrhs; 3820 + params->A = (ftyp*)a; 3821 + params->B = (ftyp*)b; 3822 + params->S = (ftyp*)s; 3823 + params->LDA = lda; 3824 + params->LDB = ldb; 3825 + 3826 + { 3827 + /* compute optimal work size */ 3828 + ftyp work_size_query; 3829 + fortran_int iwork_size_query; 3830 + 3831 + params->WORK = &work_size_query; 3832 + params->IWORK = &iwork_size_query; 3833 + params->RWORK = NULL; 3834 + params->LWORK = -1; 3835 + 3836 + if (call_gelsd(params) != 0) { 3837 + goto error; 3838 + } 3839 + work_count = (fortran_int)work_size_query; 3840 + 3841 + work_size = (size_t) work_size_query * sizeof(ftyp); 3842 + iwork_size = (size_t)iwork_size_query * sizeof(fortran_int); 3843 + } 3844 + 3845 + mem_buff2 = (npy_uint8 *)malloc(work_size + iwork_size); 3846 + if (!mem_buff2) { 3847 + goto no_memory; 3848 + } 3849 + work = mem_buff2; 3850 + iwork = work + work_size; 3851 + 3852 + params->WORK = (ftyp*)work; 3853 + params->RWORK = NULL; 3854 + params->IWORK = (fortran_int*)iwork; 3855 + params->LWORK = work_count; 3856 + 3857 + return 1; 3858 + 3859 + no_memory: 3860 + NPY_ALLOW_C_API_DEF 3861 + NPY_ALLOW_C_API; 3862 + PyErr_NoMemory(); 3863 + NPY_DISABLE_C_API; 3864 + 3865 + error: 3866 + TRACE_TXT("%s failed init\n", __FUNCTION__); 3867 + free(mem_buff); 3868 + free(mem_buff2); 3869 + memset(params, 0, sizeof(*params)); 3870 + return 0; 3871 + } 3872 + 3873 + static inline fortran_int 3874 + call_gelsd(GELSD_PARAMS_t<fortran_complex> *params) 3875 + { 3876 + fortran_int rv; 3877 + LAPACK(cgelsd)(&params->M, &params->N, &params->NRHS, 3878 + params->A, &params->LDA, 3879 + params->B, &params->LDB, 3880 + params->S, 3881 + params->RCOND, &params->RANK, 3882 + params->WORK, &params->LWORK, 3883 + params->RWORK, (fortran_int*)params->IWORK, 3884 + &rv); 3885 + return rv; 3886 + } 3887 + 3888 + static inline fortran_int 3889 + call_gelsd(GELSD_PARAMS_t<fortran_doublecomplex> *params) 3890 + { 3891 + fortran_int rv; 3892 + LAPACK(zgelsd)(&params->M, &params->N, &params->NRHS, 3893 + params->A, &params->LDA, 3894 + params->B, &params->LDB, 3895 + params->S, 3896 + params->RCOND, &params->RANK, 3897 + params->WORK, &params->LWORK, 3898 + params->RWORK, (fortran_int*)params->IWORK, 3899 + &rv); 3900 + return rv; 3901 + } 3902 + 3903 + 3904 + template<typename ftyp> 3905 + static inline int 3906 + init_gelsd(GELSD_PARAMS_t<ftyp> *params, 3907 + fortran_int m, 3908 + fortran_int n, 3909 + fortran_int nrhs, 3910 + complex_trait) 3911 + { 3912 + using frealtyp = basetype_t<ftyp>; 3913 + npy_uint8 *mem_buff = NULL; 3914 + npy_uint8 *mem_buff2 = NULL; 3915 + npy_uint8 *a, *b, *s, *work, *iwork, *rwork; 3916 + fortran_int min_m_n = fortran_int_min(m, n); 3917 + fortran_int max_m_n = fortran_int_max(m, n); 3918 + size_t safe_min_m_n = min_m_n; 3919 + size_t safe_max_m_n = max_m_n; 3920 + size_t safe_m = m; 3921 + size_t safe_n = n; 3922 + size_t safe_nrhs = nrhs; 3923 + 3924 + size_t a_size = safe_m * safe_n * sizeof(ftyp); 3925 + size_t b_size = safe_max_m_n * safe_nrhs * sizeof(ftyp); 3926 + size_t s_size = safe_min_m_n * sizeof(frealtyp); 3927 + 3928 + fortran_int work_count; 3929 + size_t work_size, rwork_size, iwork_size; 3930 + fortran_int lda = fortran_int_max(1, m); 3931 + fortran_int ldb = fortran_int_max(1, fortran_int_max(m,n)); 3932 + 3933 + size_t msize = a_size + b_size + s_size; 3934 + mem_buff = (npy_uint8 *)malloc(msize != 0 ? msize : 1); 3935 + 3936 + if (!mem_buff) { 3937 + goto no_memory; 3938 + } 3939 + 3940 + a = mem_buff; 3941 + b = a + a_size; 3942 + s = b + b_size; 3943 + 3944 + params->M = m; 3945 + params->N = n; 3946 + params->NRHS = nrhs; 3947 + params->A = (ftyp*)a; 3948 + params->B = (ftyp*)b; 3949 + params->S = (frealtyp*)s; 3950 + params->LDA = lda; 3951 + params->LDB = ldb; 3952 + 3953 + { 3954 + /* compute optimal work size */ 3955 + ftyp work_size_query; 3956 + frealtyp rwork_size_query; 3957 + fortran_int iwork_size_query; 3958 + 3959 + params->WORK = &work_size_query; 3960 + params->IWORK = &iwork_size_query; 3961 + params->RWORK = &rwork_size_query; 3962 + params->LWORK = -1; 3963 + 3964 + if (call_gelsd(params) != 0) { 3965 + goto error; 3966 + } 3967 + 3968 + work_count = (fortran_int)work_size_query.r; 3969 + 3970 + work_size = (size_t )work_size_query.r * sizeof(ftyp); 3971 + rwork_size = (size_t)rwork_size_query * sizeof(frealtyp); 3972 + iwork_size = (size_t)iwork_size_query * sizeof(fortran_int); 3973 + } 3974 + 3975 + mem_buff2 = (npy_uint8 *)malloc(work_size + rwork_size + iwork_size); 3976 + if (!mem_buff2) { 3977 + goto no_memory; 3978 + } 3979 + 3980 + work = mem_buff2; 3981 + rwork = work + work_size; 3982 + iwork = rwork + rwork_size; 3983 + 3984 + params->WORK = (ftyp*)work; 3985 + params->RWORK = (frealtyp*)rwork; 3986 + params->IWORK = (fortran_int*)iwork; 3987 + params->LWORK = work_count; 3988 + 3989 + return 1; 3990 + 3991 + no_memory: 3992 + NPY_ALLOW_C_API_DEF 3993 + NPY_ALLOW_C_API; 3994 + PyErr_NoMemory(); 3995 + NPY_DISABLE_C_API; 3996 + 3997 + error: 3998 + TRACE_TXT("%s failed init\n", __FUNCTION__); 3999 + free(mem_buff); 4000 + free(mem_buff2); 4001 + memset(params, 0, sizeof(*params)); 4002 + 4003 + return 0; 4004 + } 4005 + 4006 + template<typename ftyp> 4007 + static inline void 4008 + release_gelsd(GELSD_PARAMS_t<ftyp>* params) 4009 + { 4010 + /* A and WORK contain allocated blocks */ 4011 + free(params->A); 4012 + free(params->WORK); 4013 + memset(params, 0, sizeof(*params)); 4014 + } 4015 + 4016 + /** Compute the squared l2 norm of a contiguous vector */ 4017 + template<typename typ> 4018 + static basetype_t<typ> 4019 + abs2(typ *p, npy_intp n, scalar_trait) { 4020 + npy_intp i; 4021 + basetype_t<typ> res = 0; 4022 + for (i = 0; i < n; i++) { 4023 + typ el = p[i]; 4024 + res += el*el; 4025 + } 4026 + return res; 4027 + } 4028 + template<typename typ> 4029 + static basetype_t<typ> 4030 + abs2(typ *p, npy_intp n, complex_trait) { 4031 + npy_intp i; 4032 + basetype_t<typ> res = 0; 4033 + for (i = 0; i < n; i++) { 4034 + typ el = p[i]; 4035 + res += RE(&el)*RE(&el) + IM(&el)*IM(&el); 4036 + } 4037 + return res; 4038 + } 4039 + 4040 + 4041 + template<typename typ> 4042 + static void 4043 + lstsq(char **args, npy_intp const *dimensions, npy_intp const *steps, 4044 + void *NPY_UNUSED(func)) 4045 + { 4046 + using ftyp = fortran_type_t<typ>; 4047 + using basetyp = basetype_t<typ>; 4048 + GELSD_PARAMS_t<ftyp> params; 4049 + int error_occurred = get_fp_invalid_and_clear(); 4050 + fortran_int n, m, nrhs; 4051 + fortran_int excess; 4052 + 4053 + INIT_OUTER_LOOP_7 4054 + 4055 + m = (fortran_int)dimensions[0]; 4056 + n = (fortran_int)dimensions[1]; 4057 + nrhs = (fortran_int)dimensions[2]; 4058 + excess = m - n; 4059 + 4060 + if (init_gelsd(&params, m, n, nrhs, dispatch_scalar<ftyp>{})) { 4061 + LINEARIZE_DATA_t a_in, b_in, x_out, s_out, r_out; 4062 + 4063 + init_linearize_data(&a_in, n, m, steps[1], steps[0]); 4064 + init_linearize_data_ex(&b_in, nrhs, m, steps[3], steps[2], fortran_int_max(n, m)); 4065 + init_linearize_data_ex(&x_out, nrhs, n, steps[5], steps[4], fortran_int_max(n, m)); 4066 + init_linearize_data(&r_out, 1, nrhs, 1, steps[6]); 4067 + init_linearize_data(&s_out, 1, fortran_int_min(n, m), 1, steps[7]); 4068 + 4069 + BEGIN_OUTER_LOOP_7 4070 + int not_ok; 4071 + linearize_matrix((typ*)params.A, (typ*)args[0], &a_in); 4072 + linearize_matrix((typ*)params.B, (typ*)args[1], &b_in); 4073 + params.RCOND = (basetyp*)args[2]; 4074 + not_ok = call_gelsd(&params); 4075 + if (!not_ok) { 4076 + delinearize_matrix((typ*)args[3], (typ*)params.B, &x_out); 4077 + *(npy_int*) args[5] = params.RANK; 4078 + delinearize_matrix((basetyp*)args[6], (basetyp*)params.S, &s_out); 4079 + 4080 + /* Note that linalg.lstsq discards this when excess == 0 */ 4081 + if (excess >= 0 && params.RANK == n) { 4082 + /* Compute the residuals as the square sum of each column */ 4083 + int i; 4084 + char *resid = args[4]; 4085 + ftyp *components = (ftyp *)params.B + n; 4086 + for (i = 0; i < nrhs; i++) { 4087 + ftyp *vector = components + i*m; 4088 + /* Numpy and fortran floating types are the same size, 4089 + * so this cast is safe */ 4090 + basetyp abs = abs2((typ *)vector, excess, 4091 + dispatch_scalar<typ>{}); 4092 + memcpy( 4093 + resid + i*r_out.column_strides, 4094 + &abs, sizeof(abs)); 4095 + } 4096 + } 4097 + else { 4098 + /* Note that this is always discarded by linalg.lstsq */ 4099 + nan_matrix((basetyp*)args[4], &r_out); 4100 + } 4101 + } else { 4102 + error_occurred = 1; 4103 + nan_matrix((typ*)args[3], &x_out); 4104 + nan_matrix((basetyp*)args[4], &r_out); 4105 + *(npy_int*) args[5] = -1; 4106 + nan_matrix((basetyp*)args[6], &s_out); 4107 + } 4108 + END_OUTER_LOOP 4109 + 4110 + release_gelsd(&params); 4111 + } 4112 + 4113 + set_fp_invalid_or_clear(error_occurred); 4114 + } 4115 + 4116 + #pragma GCC diagnostic pop 4117 + 4118 + /* -------------------------------------------------------------------------- */ 4119 + /* gufunc registration */ 4120 + 4121 + static void *array_of_nulls[] = { 4122 + (void *)NULL, 4123 + (void *)NULL, 4124 + (void *)NULL, 4125 + (void *)NULL, 4126 + 4127 + (void *)NULL, 4128 + (void *)NULL, 4129 + (void *)NULL, 4130 + (void *)NULL, 4131 + 4132 + (void *)NULL, 4133 + (void *)NULL, 4134 + (void *)NULL, 4135 + (void *)NULL, 4136 + 4137 + (void *)NULL, 4138 + (void *)NULL, 4139 + (void *)NULL, 4140 + (void *)NULL 4141 + }; 4142 + 4143 + #define FUNC_ARRAY_NAME(NAME) NAME ## _funcs 4144 + 4145 + #define GUFUNC_FUNC_ARRAY_REAL(NAME) \ 4146 + static PyUFuncGenericFunction \ 4147 + FUNC_ARRAY_NAME(NAME)[] = { \ 4148 + FLOAT_ ## NAME, \ 4149 + DOUBLE_ ## NAME \ 4150 + } 4151 + 4152 + #define GUFUNC_FUNC_ARRAY_REAL_COMPLEX(NAME) \ 4153 + static PyUFuncGenericFunction \ 4154 + FUNC_ARRAY_NAME(NAME)[] = { \ 4155 + FLOAT_ ## NAME, \ 4156 + DOUBLE_ ## NAME, \ 4157 + CFLOAT_ ## NAME, \ 4158 + CDOUBLE_ ## NAME \ 4159 + } 4160 + #define GUFUNC_FUNC_ARRAY_REAL_COMPLEX_(NAME) \ 4161 + static PyUFuncGenericFunction \ 4162 + FUNC_ARRAY_NAME(NAME)[] = { \ 4163 + NAME<npy_float, npy_float>, \ 4164 + NAME<npy_double, npy_double>, \ 4165 + NAME<npy_cfloat, npy_float>, \ 4166 + NAME<npy_cdouble, npy_double> \ 4167 + } 4168 + #define GUFUNC_FUNC_ARRAY_REAL_COMPLEX__(NAME) \ 4169 + static PyUFuncGenericFunction \ 4170 + FUNC_ARRAY_NAME(NAME)[] = { \ 4171 + NAME<npy_float>, \ 4172 + NAME<npy_double>, \ 4173 + NAME<npy_cfloat>, \ 4174 + NAME<npy_cdouble> \ 4175 + } 4176 + 4177 + /* There are problems with eig in complex single precision. 4178 + * That kernel is disabled 4179 + */ 4180 + #define GUFUNC_FUNC_ARRAY_EIG(NAME) \ 4181 + static PyUFuncGenericFunction \ 4182 + FUNC_ARRAY_NAME(NAME)[] = { \ 4183 + NAME<fortran_complex,fortran_real>, \ 4184 + NAME<fortran_doublecomplex,fortran_doublereal>, \ 4185 + NAME<fortran_doublecomplex,fortran_doublecomplex> \ 4186 + } 4187 + 4188 + /* The single precision functions are not used at all, 4189 + * due to input data being promoted to double precision 4190 + * in Python, so they are not implemented here. 4191 + */ 4192 + #define GUFUNC_FUNC_ARRAY_QR(NAME) \ 4193 + static PyUFuncGenericFunction \ 4194 + FUNC_ARRAY_NAME(NAME)[] = { \ 4195 + DOUBLE_ ## NAME, \ 4196 + CDOUBLE_ ## NAME \ 4197 + } 4198 + #define GUFUNC_FUNC_ARRAY_QR__(NAME) \ 4199 + static PyUFuncGenericFunction \ 4200 + FUNC_ARRAY_NAME(NAME)[] = { \ 4201 + NAME<npy_double>, \ 4202 + NAME<npy_cdouble> \ 4203 + } 4204 + 4205 + 4206 + GUFUNC_FUNC_ARRAY_REAL_COMPLEX_(slogdet); 4207 + GUFUNC_FUNC_ARRAY_REAL_COMPLEX_(det); 4208 + GUFUNC_FUNC_ARRAY_REAL_COMPLEX__(eighlo); 4209 + GUFUNC_FUNC_ARRAY_REAL_COMPLEX__(eighup); 4210 + GUFUNC_FUNC_ARRAY_REAL_COMPLEX__(eigvalshlo); 4211 + GUFUNC_FUNC_ARRAY_REAL_COMPLEX__(eigvalshup); 4212 + GUFUNC_FUNC_ARRAY_REAL_COMPLEX__(solve); 4213 + GUFUNC_FUNC_ARRAY_REAL_COMPLEX__(solve1); 4214 + GUFUNC_FUNC_ARRAY_REAL_COMPLEX__(inv); 4215 + GUFUNC_FUNC_ARRAY_REAL_COMPLEX__(cholesky_lo); 4216 + GUFUNC_FUNC_ARRAY_REAL_COMPLEX__(svd_N); 4217 + GUFUNC_FUNC_ARRAY_REAL_COMPLEX__(svd_S); 4218 + GUFUNC_FUNC_ARRAY_REAL_COMPLEX__(svd_A); 4219 + GUFUNC_FUNC_ARRAY_QR__(qr_r_raw); 4220 + GUFUNC_FUNC_ARRAY_QR__(qr_reduced); 4221 + GUFUNC_FUNC_ARRAY_QR__(qr_complete); 4222 + GUFUNC_FUNC_ARRAY_REAL_COMPLEX__(lstsq); 4223 + GUFUNC_FUNC_ARRAY_EIG(eig); 4224 + GUFUNC_FUNC_ARRAY_EIG(eigvals); 4225 + 4226 + static char equal_2_types[] = { 4227 + NPY_FLOAT, NPY_FLOAT, 4228 + NPY_DOUBLE, NPY_DOUBLE, 4229 + NPY_CFLOAT, NPY_CFLOAT, 4230 + NPY_CDOUBLE, NPY_CDOUBLE 4231 + }; 4232 + 4233 + static char equal_3_types[] = { 4234 + NPY_FLOAT, NPY_FLOAT, NPY_FLOAT, 4235 + NPY_DOUBLE, NPY_DOUBLE, NPY_DOUBLE, 4236 + NPY_CFLOAT, NPY_CFLOAT, NPY_CFLOAT, 4237 + NPY_CDOUBLE, NPY_CDOUBLE, NPY_CDOUBLE 4238 + }; 4239 + 4240 + /* second result is logdet, that will always be a REAL */ 4241 + static char slogdet_types[] = { 4242 + NPY_FLOAT, NPY_FLOAT, NPY_FLOAT, 4243 + NPY_DOUBLE, NPY_DOUBLE, NPY_DOUBLE, 4244 + NPY_CFLOAT, NPY_CFLOAT, NPY_FLOAT, 4245 + NPY_CDOUBLE, NPY_CDOUBLE, NPY_DOUBLE 4246 + }; 4247 + 4248 + static char eigh_types[] = { 4249 + NPY_FLOAT, NPY_FLOAT, NPY_FLOAT, 4250 + NPY_DOUBLE, NPY_DOUBLE, NPY_DOUBLE, 4251 + NPY_CFLOAT, NPY_FLOAT, NPY_CFLOAT, 4252 + NPY_CDOUBLE, NPY_DOUBLE, NPY_CDOUBLE 4253 + }; 4254 + 4255 + static char eighvals_types[] = { 4256 + NPY_FLOAT, NPY_FLOAT, 4257 + NPY_DOUBLE, NPY_DOUBLE, 4258 + NPY_CFLOAT, NPY_FLOAT, 4259 + NPY_CDOUBLE, NPY_DOUBLE 4260 + }; 4261 + 4262 + static char eig_types[] = { 4263 + NPY_FLOAT, NPY_CFLOAT, NPY_CFLOAT, 4264 + NPY_DOUBLE, NPY_CDOUBLE, NPY_CDOUBLE, 4265 + NPY_CDOUBLE, NPY_CDOUBLE, NPY_CDOUBLE 4266 + }; 4267 + 4268 + static char eigvals_types[] = { 4269 + NPY_FLOAT, NPY_CFLOAT, 4270 + NPY_DOUBLE, NPY_CDOUBLE, 4271 + NPY_CDOUBLE, NPY_CDOUBLE 4272 + }; 4273 + 4274 + static char svd_1_1_types[] = { 4275 + NPY_FLOAT, NPY_FLOAT, 4276 + NPY_DOUBLE, NPY_DOUBLE, 4277 + NPY_CFLOAT, NPY_FLOAT, 4278 + NPY_CDOUBLE, NPY_DOUBLE 4279 + }; 4280 + 4281 + static char svd_1_3_types[] = { 4282 + NPY_FLOAT, NPY_FLOAT, NPY_FLOAT, NPY_FLOAT, 4283 + NPY_DOUBLE, NPY_DOUBLE, NPY_DOUBLE, NPY_DOUBLE, 4284 + NPY_CFLOAT, NPY_CFLOAT, NPY_FLOAT, NPY_CFLOAT, 4285 + NPY_CDOUBLE, NPY_CDOUBLE, NPY_DOUBLE, NPY_CDOUBLE 4286 + }; 4287 + 4288 + /* A, tau */ 4289 + static char qr_r_raw_types[] = { 4290 + NPY_DOUBLE, NPY_DOUBLE, 4291 + NPY_CDOUBLE, NPY_CDOUBLE, 4292 + }; 4293 + 4294 + /* A, tau, q */ 4295 + static char qr_reduced_types[] = { 4296 + NPY_DOUBLE, NPY_DOUBLE, NPY_DOUBLE, 4297 + NPY_CDOUBLE, NPY_CDOUBLE, NPY_CDOUBLE, 4298 + }; 4299 + 4300 + /* A, tau, q */ 4301 + static char qr_complete_types[] = { 4302 + NPY_DOUBLE, NPY_DOUBLE, NPY_DOUBLE, 4303 + NPY_CDOUBLE, NPY_CDOUBLE, NPY_CDOUBLE, 4304 + }; 4305 + 4306 + /* A, b, rcond, x, resid, rank, s, */ 4307 + static char lstsq_types[] = { 4308 + NPY_FLOAT, NPY_FLOAT, NPY_FLOAT, NPY_FLOAT, NPY_FLOAT, NPY_INT, NPY_FLOAT, 4309 + NPY_DOUBLE, NPY_DOUBLE, NPY_DOUBLE, NPY_DOUBLE, NPY_DOUBLE, NPY_INT, NPY_DOUBLE, 4310 + NPY_CFLOAT, NPY_CFLOAT, NPY_FLOAT, NPY_CFLOAT, NPY_FLOAT, NPY_INT, NPY_FLOAT, 4311 + NPY_CDOUBLE, NPY_CDOUBLE, NPY_DOUBLE, NPY_CDOUBLE, NPY_DOUBLE, NPY_INT, NPY_DOUBLE, 4312 + }; 4313 + 4314 + typedef struct gufunc_descriptor_struct { 4315 + const char *name; 4316 + const char *signature; 4317 + const char *doc; 4318 + int ntypes; 4319 + int nin; 4320 + int nout; 4321 + PyUFuncGenericFunction *funcs; 4322 + char *types; 4323 + } GUFUNC_DESCRIPTOR_t; 4324 + 4325 + GUFUNC_DESCRIPTOR_t gufunc_descriptors [] = { 4326 + { 4327 + "slogdet", 4328 + "(m,m)->(),()", 4329 + "slogdet on the last two dimensions and broadcast on the rest. \n"\ 4330 + "Results in two arrays, one with sign and the other with log of the"\ 4331 + " determinants. \n"\ 4332 + " \"(m,m)->(),()\" \n", 4333 + 4, 1, 2, 4334 + FUNC_ARRAY_NAME(slogdet), 4335 + slogdet_types 4336 + }, 4337 + { 4338 + "det", 4339 + "(m,m)->()", 4340 + "det of the last two dimensions and broadcast on the rest. \n"\ 4341 + " \"(m,m)->()\" \n", 4342 + 4, 1, 1, 4343 + FUNC_ARRAY_NAME(det), 4344 + equal_2_types 4345 + }, 4346 + { 4347 + "eigh_lo", 4348 + "(m,m)->(m),(m,m)", 4349 + "eigh on the last two dimension and broadcast to the rest, using"\ 4350 + " lower triangle \n"\ 4351 + "Results in a vector of eigenvalues and a matrix with the"\ 4352 + "eigenvectors. \n"\ 4353 + " \"(m,m)->(m),(m,m)\" \n", 4354 + 4, 1, 2, 4355 + FUNC_ARRAY_NAME(eighlo), 4356 + eigh_types 4357 + }, 4358 + { 4359 + "eigh_up", 4360 + "(m,m)->(m),(m,m)", 4361 + "eigh on the last two dimension and broadcast to the rest, using"\ 4362 + " upper triangle. \n"\ 4363 + "Results in a vector of eigenvalues and a matrix with the"\ 4364 + " eigenvectors. \n"\ 4365 + " \"(m,m)->(m),(m,m)\" \n", 4366 + 4, 1, 2, 4367 + FUNC_ARRAY_NAME(eighup), 4368 + eigh_types 4369 + }, 4370 + { 4371 + "eigvalsh_lo", 4372 + "(m,m)->(m)", 4373 + "eigh on the last two dimension and broadcast to the rest, using"\ 4374 + " lower triangle. \n"\ 4375 + "Results in a vector of eigenvalues and a matrix with the"\ 4376 + "eigenvectors. \n"\ 4377 + " \"(m,m)->(m)\" \n", 4378 + 4, 1, 1, 4379 + FUNC_ARRAY_NAME(eigvalshlo), 4380 + eighvals_types 4381 + }, 4382 + { 4383 + "eigvalsh_up", 4384 + "(m,m)->(m)", 4385 + "eigvalsh on the last two dimension and broadcast to the rest,"\ 4386 + " using upper triangle. \n"\ 4387 + "Results in a vector of eigenvalues and a matrix with the"\ 4388 + "eigenvectors.\n"\ 4389 + " \"(m,m)->(m)\" \n", 4390 + 4, 1, 1, 4391 + FUNC_ARRAY_NAME(eigvalshup), 4392 + eighvals_types 4393 + }, 4394 + { 4395 + "solve", 4396 + "(m,m),(m,n)->(m,n)", 4397 + "solve the system a x = b, on the last two dimensions, broadcast"\ 4398 + " to the rest. \n"\ 4399 + "Results in a matrices with the solutions. \n"\ 4400 + " \"(m,m),(m,n)->(m,n)\" \n", 4401 + 4, 2, 1, 4402 + FUNC_ARRAY_NAME(solve), 4403 + equal_3_types 4404 + }, 4405 + { 4406 + "solve1", 4407 + "(m,m),(m)->(m)", 4408 + "solve the system a x = b, for b being a vector, broadcast in"\ 4409 + " the outer dimensions. \n"\ 4410 + "Results in vectors with the solutions. \n"\ 4411 + " \"(m,m),(m)->(m)\" \n", 4412 + 4, 2, 1, 4413 + FUNC_ARRAY_NAME(solve1), 4414 + equal_3_types 4415 + }, 4416 + { 4417 + "inv", 4418 + "(m, m)->(m, m)", 4419 + "compute the inverse of the last two dimensions and broadcast"\ 4420 + " to the rest. \n"\ 4421 + "Results in the inverse matrices. \n"\ 4422 + " \"(m,m)->(m,m)\" \n", 4423 + 4, 1, 1, 4424 + FUNC_ARRAY_NAME(inv), 4425 + equal_2_types 4426 + }, 4427 + { 4428 + "cholesky_lo", 4429 + "(m,m)->(m,m)", 4430 + "cholesky decomposition of hermitian positive-definite matrices. \n"\ 4431 + "Broadcast to all outer dimensions. \n"\ 4432 + " \"(m,m)->(m,m)\" \n", 4433 + 4, 1, 1, 4434 + FUNC_ARRAY_NAME(cholesky_lo), 4435 + equal_2_types 4436 + }, 4437 + { 4438 + "svd_m", 4439 + "(m,n)->(m)", 4440 + "svd when n>=m. ", 4441 + 4, 1, 1, 4442 + FUNC_ARRAY_NAME(svd_N), 4443 + svd_1_1_types 4444 + }, 4445 + { 4446 + "svd_n", 4447 + "(m,n)->(n)", 4448 + "svd when n<=m", 4449 + 4, 1, 1, 4450 + FUNC_ARRAY_NAME(svd_N), 4451 + svd_1_1_types 4452 + }, 4453 + { 4454 + "svd_m_s", 4455 + "(m,n)->(m,m),(m),(m,n)", 4456 + "svd when m<=n", 4457 + 4, 1, 3, 4458 + FUNC_ARRAY_NAME(svd_S), 4459 + svd_1_3_types 4460 + }, 4461 + { 4462 + "svd_n_s", 4463 + "(m,n)->(m,n),(n),(n,n)", 4464 + "svd when m>=n", 4465 + 4, 1, 3, 4466 + FUNC_ARRAY_NAME(svd_S), 4467 + svd_1_3_types 4468 + }, 4469 + { 4470 + "svd_m_f", 4471 + "(m,n)->(m,m),(m),(n,n)", 4472 + "svd when m<=n", 4473 + 4, 1, 3, 4474 + FUNC_ARRAY_NAME(svd_A), 4475 + svd_1_3_types 4476 + }, 4477 + { 4478 + "svd_n_f", 4479 + "(m,n)->(m,m),(n),(n,n)", 4480 + "svd when m>=n", 4481 + 4, 1, 3, 4482 + FUNC_ARRAY_NAME(svd_A), 4483 + svd_1_3_types 4484 + }, 4485 + { 4486 + "eig", 4487 + "(m,m)->(m),(m,m)", 4488 + "eig on the last two dimension and broadcast to the rest. \n"\ 4489 + "Results in a vector with the eigenvalues and a matrix with the"\ 4490 + " eigenvectors. \n"\ 4491 + " \"(m,m)->(m),(m,m)\" \n", 4492 + 3, 1, 2, 4493 + FUNC_ARRAY_NAME(eig), 4494 + eig_types 4495 + }, 4496 + { 4497 + "eigvals", 4498 + "(m,m)->(m)", 4499 + "eigvals on the last two dimension and broadcast to the rest. \n"\ 4500 + "Results in a vector of eigenvalues. \n", 4501 + 3, 1, 1, 4502 + FUNC_ARRAY_NAME(eigvals), 4503 + eigvals_types 4504 + }, 4505 + { 4506 + "qr_r_raw_m", 4507 + "(m,n)->(m)", 4508 + "Compute TAU vector for the last two dimensions \n"\ 4509 + "and broadcast to the rest. For m <= n. \n", 4510 + 2, 1, 1, 4511 + FUNC_ARRAY_NAME(qr_r_raw), 4512 + qr_r_raw_types 4513 + }, 4514 + { 4515 + "qr_r_raw_n", 4516 + "(m,n)->(n)", 4517 + "Compute TAU vector for the last two dimensions \n"\ 4518 + "and broadcast to the rest. For m > n. \n", 4519 + 2, 1, 1, 4520 + FUNC_ARRAY_NAME(qr_r_raw), 4521 + qr_r_raw_types 4522 + }, 4523 + { 4524 + "qr_reduced", 4525 + "(m,n),(k)->(m,k)", 4526 + "Compute Q matrix for the last two dimensions \n"\ 4527 + "and broadcast to the rest. \n", 4528 + 2, 2, 1, 4529 + FUNC_ARRAY_NAME(qr_reduced), 4530 + qr_reduced_types 4531 + }, 4532 + { 4533 + "qr_complete", 4534 + "(m,n),(n)->(m,m)", 4535 + "Compute Q matrix for the last two dimensions \n"\ 4536 + "and broadcast to the rest. For m > n. \n", 4537 + 2, 2, 1, 4538 + FUNC_ARRAY_NAME(qr_complete), 4539 + qr_complete_types 4540 + }, 4541 + { 4542 + "lstsq_m", 4543 + "(m,n),(m,nrhs),()->(n,nrhs),(nrhs),(),(m)", 4544 + "least squares on the last two dimensions and broadcast to the rest. \n"\ 4545 + "For m <= n. \n", 4546 + 4, 3, 4, 4547 + FUNC_ARRAY_NAME(lstsq), 4548 + lstsq_types 4549 + }, 4550 + { 4551 + "lstsq_n", 4552 + "(m,n),(m,nrhs),()->(n,nrhs),(nrhs),(),(n)", 4553 + "least squares on the last two dimensions and broadcast to the rest. \n"\ 4554 + "For m >= n, meaning that residuals are produced. \n", 4555 + 4, 3, 4, 4556 + FUNC_ARRAY_NAME(lstsq), 4557 + lstsq_types 4558 + } 4559 + }; 4560 + 4561 + static int 4562 + addUfuncs(PyObject *dictionary) { 4563 + PyObject *f; 4564 + int i; 4565 + const int gufunc_count = sizeof(gufunc_descriptors)/ 4566 + sizeof(gufunc_descriptors[0]); 4567 + for (i = 0; i < gufunc_count; i++) { 4568 + GUFUNC_DESCRIPTOR_t* d = &gufunc_descriptors[i]; 4569 + f = PyUFunc_FromFuncAndDataAndSignature(d->funcs, 4570 + array_of_nulls, 4571 + d->types, 4572 + d->ntypes, 4573 + d->nin, 4574 + d->nout, 4575 + PyUFunc_None, 4576 + d->name, 4577 + d->doc, 4578 + 0, 4579 + d->signature); 4580 + if (f == NULL) { 4581 + return -1; 4582 + } 4583 + #if 0 4584 + dump_ufunc_object((PyUFuncObject*) f); 4585 + #endif 4586 + int ret = PyDict_SetItemString(dictionary, d->name, f); 4587 + Py_DECREF(f); 4588 + if (ret < 0) { 4589 + return -1; 4590 + } 4591 + } 4592 + return 0; 4593 + } 4594 + 4595 + 4596 + 4597 + /* -------------------------------------------------------------------------- */ 4598 + /* Module initialization stuff */ 4599 + 4600 + static PyMethodDef UMath_LinAlgMethods[] = { 4601 + {NULL, NULL, 0, NULL} /* Sentinel */ 4602 + }; 4603 + 4604 + static struct PyModuleDef moduledef = { 4605 + PyModuleDef_HEAD_INIT, 4606 + UMATH_LINALG_MODULE_NAME, 4607 + NULL, 4608 + -1, 4609 + UMath_LinAlgMethods, 4610 + NULL, 4611 + NULL, 4612 + NULL, 4613 + NULL 4614 + }; 4615 + 4616 + PyMODINIT_FUNC PyInit__umath_linalg(void) 4617 + { 4618 + PyObject *m; 4619 + PyObject *d; 4620 + PyObject *version; 4621 + 4622 + m = PyModule_Create(&moduledef); 4623 + if (m == NULL) { 4624 + return NULL; 4625 + } 4626 + 4627 + import_array(); 4628 + import_ufunc(); 4629 + 4630 + d = PyModule_GetDict(m); 4631 + if (d == NULL) { 4632 + return NULL; 4633 + } 4634 + 4635 + version = PyUnicode_FromString(umath_linalg_version_string); 4636 + if (version == NULL) { 4637 + return NULL; 4638 + } 4639 + int ret = PyDict_SetItemString(d, "__version__", version); 4640 + Py_DECREF(version); 4641 + if (ret < 0) { 4642 + return NULL; 4643 + } 4644 + 4645 + /* Load the ufunc operators into the module's namespace */ 4646 + if (addUfuncs(d) < 0) { 4647 + return NULL; 4648 + } 4649 + 4650 + #ifdef HAVE_BLAS_ILP64 4651 + PyDict_SetItemString(d, "_ilp64", Py_True); 4652 + #else 4653 + PyDict_SetItemString(d, "_ilp64", Py_False); 4654 + #endif 4655 + 4656 + return m; 4657 + }