fork of https://github.com/sourcegraph/zoekt
1/* -*- c -*- */
2
3/*
4 *****************************************************************************
5 ** INCLUDES **
6 *****************************************************************************
7 */
8#define PY_SSIZE_T_CLEAN
9#include <Python.h>
10
11#define NPY_NO_DEPRECATED_API NPY_API_VERSION
12#include "numpy/arrayobject.h"
13#include "numpy/ufuncobject.h"
14#include "numpy/npy_math.h"
15
16#include "npy_pycompat.h"
17
18#include "npy_config.h"
19
20#include "npy_cblas.h"
21
22#include <cstddef>
23#include <cstdio>
24#include <cassert>
25#include <cmath>
26#include <type_traits>
27#include <utility>
28
29
/* Version string of this umath_linalg implementation. */
static const char* umath_linalg_version_string = "0.1.5";
31
32/*
33 ****************************************************************************
34 * Debugging support *
35 ****************************************************************************
36 */
/*
 * TRACE_TXT(...): printf-style message written to stderr.
 * STACK_TRACE:    expands to an empty statement in this configuration.
 * TRACE:          writes "file:line:function" to stderr, then runs
 *                 STACK_TRACE.
 * All three are wrapped in do { } while (0) so they behave as one statement.
 */
#define TRACE_TXT(...) do { fprintf (stderr, __VA_ARGS__); } while (0)
#define STACK_TRACE do {} while (0)
#define TRACE\
 do { \
 fprintf (stderr, \
 "%s:%d:%s\n", \
 __FILE__, \
 __LINE__, \
 __FUNCTION__); \
 STACK_TRACE; \
 } while (0)
48
/*
 * Optional stack-trace support. The whole section is compiled out by the
 * enclosing `#if 0`; while it stays disabled STACK_TRACE keeps its empty
 * definition.
 */
#if 0
#if defined HAVE_EXECINFO_H
#include <execinfo.h>
#elif defined HAVE_LIBUNWIND_H
#include <libunwind.h>
#endif
/* Capture up to 32 return addresses and write their symbols to fd 1. */
void
dbg_stack_trace()
{
 void *trace[32];
 size_t size;

 size = backtrace(trace, sizeof(trace)/sizeof(trace[0]));
 backtrace_symbols_fd(trace, size, 1);
}

/* Re-point STACK_TRACE at the helper above. */
#undef STACK_TRACE
#define STACK_TRACE do { dbg_stack_trace(); } while (0)
#endif
68
69/*
70 *****************************************************************************
71 * BLAS/LAPACK calling macros *
72 *****************************************************************************
73 */
74
/* FNAME(x): the linker-level symbol for routine x, as produced by BLAS_FUNC. */
#define FNAME(x) BLAS_FUNC(x)

/* Integer type used for the integer arguments of the BLAS/LAPACK prototypes. */
typedef CBLAS_INT fortran_int;

/* Complex element types stored as two fields, r and i. */
typedef struct { float r, i; } f2c_complex;
typedef struct { double r, i; } f2c_doublecomplex;
/* typedef long int (*L_fp)(); */

/* Aliases for the four element types handled by the routines in this file. */
typedef float fortran_real;
typedef double fortran_doublereal;
typedef f2c_complex fortran_complex;
typedef f2c_doublecomplex fortran_doublecomplex;
87
88extern "C" fortran_int
89FNAME(sgeev)(char *jobvl, char *jobvr, fortran_int *n,
90 float a[], fortran_int *lda, float wr[], float wi[],
91 float vl[], fortran_int *ldvl, float vr[], fortran_int *ldvr,
92 float work[], fortran_int lwork[],
93 fortran_int *info);
94extern "C" fortran_int
95FNAME(dgeev)(char *jobvl, char *jobvr, fortran_int *n,
96 double a[], fortran_int *lda, double wr[], double wi[],
97 double vl[], fortran_int *ldvl, double vr[], fortran_int *ldvr,
98 double work[], fortran_int lwork[],
99 fortran_int *info);
100extern "C" fortran_int
101FNAME(cgeev)(char *jobvl, char *jobvr, fortran_int *n,
102 f2c_complex a[], fortran_int *lda,
103 f2c_complex w[],
104 f2c_complex vl[], fortran_int *ldvl,
105 f2c_complex vr[], fortran_int *ldvr,
106 f2c_complex work[], fortran_int *lwork,
107 float rwork[],
108 fortran_int *info);
109extern "C" fortran_int
110FNAME(zgeev)(char *jobvl, char *jobvr, fortran_int *n,
111 f2c_doublecomplex a[], fortran_int *lda,
112 f2c_doublecomplex w[],
113 f2c_doublecomplex vl[], fortran_int *ldvl,
114 f2c_doublecomplex vr[], fortran_int *ldvr,
115 f2c_doublecomplex work[], fortran_int *lwork,
116 double rwork[],
117 fortran_int *info);
118
119extern "C" fortran_int
120FNAME(ssyevd)(char *jobz, char *uplo, fortran_int *n,
121 float a[], fortran_int *lda, float w[], float work[],
122 fortran_int *lwork, fortran_int iwork[], fortran_int *liwork,
123 fortran_int *info);
124extern "C" fortran_int
125FNAME(dsyevd)(char *jobz, char *uplo, fortran_int *n,
126 double a[], fortran_int *lda, double w[], double work[],
127 fortran_int *lwork, fortran_int iwork[], fortran_int *liwork,
128 fortran_int *info);
129extern "C" fortran_int
130FNAME(cheevd)(char *jobz, char *uplo, fortran_int *n,
131 f2c_complex a[], fortran_int *lda,
132 float w[], f2c_complex work[],
133 fortran_int *lwork, float rwork[], fortran_int *lrwork, fortran_int iwork[],
134 fortran_int *liwork,
135 fortran_int *info);
136extern "C" fortran_int
137FNAME(zheevd)(char *jobz, char *uplo, fortran_int *n,
138 f2c_doublecomplex a[], fortran_int *lda,
139 double w[], f2c_doublecomplex work[],
140 fortran_int *lwork, double rwork[], fortran_int *lrwork, fortran_int iwork[],
141 fortran_int *liwork,
142 fortran_int *info);
143
144extern "C" fortran_int
145FNAME(sgelsd)(fortran_int *m, fortran_int *n, fortran_int *nrhs,
146 float a[], fortran_int *lda, float b[], fortran_int *ldb,
147 float s[], float *rcond, fortran_int *rank,
148 float work[], fortran_int *lwork, fortran_int iwork[],
149 fortran_int *info);
150extern "C" fortran_int
151FNAME(dgelsd)(fortran_int *m, fortran_int *n, fortran_int *nrhs,
152 double a[], fortran_int *lda, double b[], fortran_int *ldb,
153 double s[], double *rcond, fortran_int *rank,
154 double work[], fortran_int *lwork, fortran_int iwork[],
155 fortran_int *info);
156extern "C" fortran_int
157FNAME(cgelsd)(fortran_int *m, fortran_int *n, fortran_int *nrhs,
158 f2c_complex a[], fortran_int *lda,
159 f2c_complex b[], fortran_int *ldb,
160 float s[], float *rcond, fortran_int *rank,
161 f2c_complex work[], fortran_int *lwork,
162 float rwork[], fortran_int iwork[],
163 fortran_int *info);
164extern "C" fortran_int
165FNAME(zgelsd)(fortran_int *m, fortran_int *n, fortran_int *nrhs,
166 f2c_doublecomplex a[], fortran_int *lda,
167 f2c_doublecomplex b[], fortran_int *ldb,
168 double s[], double *rcond, fortran_int *rank,
169 f2c_doublecomplex work[], fortran_int *lwork,
170 double rwork[], fortran_int iwork[],
171 fortran_int *info);
172
173extern "C" fortran_int
174FNAME(dgeqrf)(fortran_int *m, fortran_int *n, double a[], fortran_int *lda,
175 double tau[], double work[],
176 fortran_int *lwork, fortran_int *info);
177extern "C" fortran_int
178FNAME(zgeqrf)(fortran_int *m, fortran_int *n, f2c_doublecomplex a[], fortran_int *lda,
179 f2c_doublecomplex tau[], f2c_doublecomplex work[],
180 fortran_int *lwork, fortran_int *info);
181
182extern "C" fortran_int
183FNAME(dorgqr)(fortran_int *m, fortran_int *n, fortran_int *k, double a[], fortran_int *lda,
184 double tau[], double work[],
185 fortran_int *lwork, fortran_int *info);
186extern "C" fortran_int
187FNAME(zungqr)(fortran_int *m, fortran_int *n, fortran_int *k, f2c_doublecomplex a[],
188 fortran_int *lda, f2c_doublecomplex tau[],
189 f2c_doublecomplex work[], fortran_int *lwork, fortran_int *info);
190
191extern "C" fortran_int
192FNAME(sgesv)(fortran_int *n, fortran_int *nrhs,
193 float a[], fortran_int *lda,
194 fortran_int ipiv[],
195 float b[], fortran_int *ldb,
196 fortran_int *info);
197extern "C" fortran_int
198FNAME(dgesv)(fortran_int *n, fortran_int *nrhs,
199 double a[], fortran_int *lda,
200 fortran_int ipiv[],
201 double b[], fortran_int *ldb,
202 fortran_int *info);
203extern "C" fortran_int
204FNAME(cgesv)(fortran_int *n, fortran_int *nrhs,
205 f2c_complex a[], fortran_int *lda,
206 fortran_int ipiv[],
207 f2c_complex b[], fortran_int *ldb,
208 fortran_int *info);
209extern "C" fortran_int
210FNAME(zgesv)(fortran_int *n, fortran_int *nrhs,
211 f2c_doublecomplex a[], fortran_int *lda,
212 fortran_int ipiv[],
213 f2c_doublecomplex b[], fortran_int *ldb,
214 fortran_int *info);
215
216extern "C" fortran_int
217FNAME(sgetrf)(fortran_int *m, fortran_int *n,
218 float a[], fortran_int *lda,
219 fortran_int ipiv[],
220 fortran_int *info);
221extern "C" fortran_int
222FNAME(dgetrf)(fortran_int *m, fortran_int *n,
223 double a[], fortran_int *lda,
224 fortran_int ipiv[],
225 fortran_int *info);
226extern "C" fortran_int
227FNAME(cgetrf)(fortran_int *m, fortran_int *n,
228 f2c_complex a[], fortran_int *lda,
229 fortran_int ipiv[],
230 fortran_int *info);
231extern "C" fortran_int
232FNAME(zgetrf)(fortran_int *m, fortran_int *n,
233 f2c_doublecomplex a[], fortran_int *lda,
234 fortran_int ipiv[],
235 fortran_int *info);
236
237extern "C" fortran_int
238FNAME(spotrf)(char *uplo, fortran_int *n,
239 float a[], fortran_int *lda,
240 fortran_int *info);
241extern "C" fortran_int
242FNAME(dpotrf)(char *uplo, fortran_int *n,
243 double a[], fortran_int *lda,
244 fortran_int *info);
245extern "C" fortran_int
246FNAME(cpotrf)(char *uplo, fortran_int *n,
247 f2c_complex a[], fortran_int *lda,
248 fortran_int *info);
249extern "C" fortran_int
250FNAME(zpotrf)(char *uplo, fortran_int *n,
251 f2c_doublecomplex a[], fortran_int *lda,
252 fortran_int *info);
253
254extern "C" fortran_int
255FNAME(sgesdd)(char *jobz, fortran_int *m, fortran_int *n,
256 float a[], fortran_int *lda, float s[], float u[],
257 fortran_int *ldu, float vt[], fortran_int *ldvt, float work[],
258 fortran_int *lwork, fortran_int iwork[], fortran_int *info);
259extern "C" fortran_int
260FNAME(dgesdd)(char *jobz, fortran_int *m, fortran_int *n,
261 double a[], fortran_int *lda, double s[], double u[],
262 fortran_int *ldu, double vt[], fortran_int *ldvt, double work[],
263 fortran_int *lwork, fortran_int iwork[], fortran_int *info);
264extern "C" fortran_int
265FNAME(cgesdd)(char *jobz, fortran_int *m, fortran_int *n,
266 f2c_complex a[], fortran_int *lda,
267 float s[], f2c_complex u[], fortran_int *ldu,
268 f2c_complex vt[], fortran_int *ldvt,
269 f2c_complex work[], fortran_int *lwork,
270 float rwork[], fortran_int iwork[], fortran_int *info);
271extern "C" fortran_int
272FNAME(zgesdd)(char *jobz, fortran_int *m, fortran_int *n,
273 f2c_doublecomplex a[], fortran_int *lda,
274 double s[], f2c_doublecomplex u[], fortran_int *ldu,
275 f2c_doublecomplex vt[], fortran_int *ldvt,
276 f2c_doublecomplex work[], fortran_int *lwork,
277 double rwork[], fortran_int iwork[], fortran_int *info);
278
279extern "C" fortran_int
280FNAME(spotrs)(char *uplo, fortran_int *n, fortran_int *nrhs,
281 float a[], fortran_int *lda,
282 float b[], fortran_int *ldb,
283 fortran_int *info);
284extern "C" fortran_int
285FNAME(dpotrs)(char *uplo, fortran_int *n, fortran_int *nrhs,
286 double a[], fortran_int *lda,
287 double b[], fortran_int *ldb,
288 fortran_int *info);
289extern "C" fortran_int
290FNAME(cpotrs)(char *uplo, fortran_int *n, fortran_int *nrhs,
291 f2c_complex a[], fortran_int *lda,
292 f2c_complex b[], fortran_int *ldb,
293 fortran_int *info);
294extern "C" fortran_int
295FNAME(zpotrs)(char *uplo, fortran_int *n, fortran_int *nrhs,
296 f2c_doublecomplex a[], fortran_int *lda,
297 f2c_doublecomplex b[], fortran_int *ldb,
298 fortran_int *info);
299
300extern "C" fortran_int
301FNAME(spotri)(char *uplo, fortran_int *n,
302 float a[], fortran_int *lda,
303 fortran_int *info);
304extern "C" fortran_int
305FNAME(dpotri)(char *uplo, fortran_int *n,
306 double a[], fortran_int *lda,
307 fortran_int *info);
308extern "C" fortran_int
309FNAME(cpotri)(char *uplo, fortran_int *n,
310 f2c_complex a[], fortran_int *lda,
311 fortran_int *info);
312extern "C" fortran_int
313FNAME(zpotri)(char *uplo, fortran_int *n,
314 f2c_doublecomplex a[], fortran_int *lda,
315 fortran_int *info);
316
317extern "C" fortran_int
318FNAME(scopy)(fortran_int *n,
319 float *sx, fortran_int *incx,
320 float *sy, fortran_int *incy);
321extern "C" fortran_int
322FNAME(dcopy)(fortran_int *n,
323 double *sx, fortran_int *incx,
324 double *sy, fortran_int *incy);
325extern "C" fortran_int
326FNAME(ccopy)(fortran_int *n,
327 f2c_complex *sx, fortran_int *incx,
328 f2c_complex *sy, fortran_int *incy);
329extern "C" fortran_int
330FNAME(zcopy)(fortran_int *n,
331 f2c_doublecomplex *sx, fortran_int *incx,
332 f2c_doublecomplex *sy, fortran_int *incy);
333
334extern "C" float
335FNAME(sdot)(fortran_int *n,
336 float *sx, fortran_int *incx,
337 float *sy, fortran_int *incy);
338extern "C" double
339FNAME(ddot)(fortran_int *n,
340 double *sx, fortran_int *incx,
341 double *sy, fortran_int *incy);
342extern "C" void
343FNAME(cdotu)(f2c_complex *ret, fortran_int *n,
344 f2c_complex *sx, fortran_int *incx,
345 f2c_complex *sy, fortran_int *incy);
346extern "C" void
347FNAME(zdotu)(f2c_doublecomplex *ret, fortran_int *n,
348 f2c_doublecomplex *sx, fortran_int *incx,
349 f2c_doublecomplex *sy, fortran_int *incy);
350extern "C" void
351FNAME(cdotc)(f2c_complex *ret, fortran_int *n,
352 f2c_complex *sx, fortran_int *incx,
353 f2c_complex *sy, fortran_int *incy);
354extern "C" void
355FNAME(zdotc)(f2c_doublecomplex *ret, fortran_int *n,
356 f2c_doublecomplex *sx, fortran_int *incx,
357 f2c_doublecomplex *sy, fortran_int *incy);
358
359extern "C" fortran_int
360FNAME(sgemm)(char *transa, char *transb,
361 fortran_int *m, fortran_int *n, fortran_int *k,
362 float *alpha,
363 float *a, fortran_int *lda,
364 float *b, fortran_int *ldb,
365 float *beta,
366 float *c, fortran_int *ldc);
367extern "C" fortran_int
368FNAME(dgemm)(char *transa, char *transb,
369 fortran_int *m, fortran_int *n, fortran_int *k,
370 double *alpha,
371 double *a, fortran_int *lda,
372 double *b, fortran_int *ldb,
373 double *beta,
374 double *c, fortran_int *ldc);
375extern "C" fortran_int
376FNAME(cgemm)(char *transa, char *transb,
377 fortran_int *m, fortran_int *n, fortran_int *k,
378 f2c_complex *alpha,
379 f2c_complex *a, fortran_int *lda,
380 f2c_complex *b, fortran_int *ldb,
381 f2c_complex *beta,
382 f2c_complex *c, fortran_int *ldc);
383extern "C" fortran_int
384FNAME(zgemm)(char *transa, char *transb,
385 fortran_int *m, fortran_int *n, fortran_int *k,
386 f2c_doublecomplex *alpha,
387 f2c_doublecomplex *a, fortran_int *lda,
388 f2c_doublecomplex *b, fortran_int *ldb,
389 f2c_doublecomplex *beta,
390 f2c_doublecomplex *c, fortran_int *ldc);
391
392
/*
 * LAPACK_T(FUNC): emits a TRACE_TXT message naming FUNC, followed by the
 * FNAME-mangled symbol. NOTE: the expansion is two statements, so it must not
 * be used as the sole body of an unbraced if/else or inside an expression.
 */
#define LAPACK_T(FUNC) \
 TRACE_TXT("Calling LAPACK ( " # FUNC " )\n"); \
 FNAME(FUNC)

/* BLAS(FUNC) and LAPACK(FUNC): plain aliases for FNAME(FUNC). */
#define BLAS(FUNC) \
 FNAME(FUNC)

#define LAPACK(FUNC) \
 FNAME(FUNC)
402
403
404/*
405 *****************************************************************************
406 ** Some handy functions **
407 *****************************************************************************
408 */
409
410static inline int
411get_fp_invalid_and_clear(void)
412{
413 int status;
414 status = npy_clear_floatstatus_barrier((char*)&status);
415 return !!(status & NPY_FPE_INVALID);
416}
417
418static inline void
419set_fp_invalid_or_clear(int error_occurred)
420{
421 if (error_occurred) {
422 npy_set_floatstatus_invalid();
423 }
424 else {
425 npy_clear_floatstatus_barrier((char*)&error_occurred);
426 }
427}
428
429/*
430 *****************************************************************************
431 ** Some handy constants **
432 *****************************************************************************
433 */
434
435#define UMATH_LINALG_MODULE_NAME "_umath_linalg"
436
/*
 * Per-type numeric constants (one, zero, minus_one, ninf, nan).
 * The primary template is only declared; each supported element type
 * provides an explicit specialization.
 */
template<typename T>
struct numeric_limits;
439
/*
 * Constants for float. one/zero/minus_one are constexpr members with
 * in-class initializers; ninf and nan are defined out of class from
 * -NPY_INFINITYF and NPY_NANF.
 */
template<>
struct numeric_limits<float> {
static constexpr float one = 1.0f;
static constexpr float zero = 0.0f;
static constexpr float minus_one = -1.0f;
static const float ninf;
static const float nan;
};
/* out-of-class definitions of the static members declared above */
constexpr float numeric_limits<float>::one;
constexpr float numeric_limits<float>::zero;
constexpr float numeric_limits<float>::minus_one;
const float numeric_limits<float>::ninf = -NPY_INFINITYF;
const float numeric_limits<float>::nan = NPY_NANF;
453
/*
 * Constants for double. one/zero/minus_one are constexpr members with
 * in-class initializers; ninf and nan are defined out of class from
 * -NPY_INFINITY and NPY_NAN.
 */
template<>
struct numeric_limits<double> {
static constexpr double one = 1.0;
static constexpr double zero = 0.0;
static constexpr double minus_one = -1.0;
static const double ninf;
static const double nan;
};
/* out-of-class definitions of the static members declared above */
constexpr double numeric_limits<double>::one;
constexpr double numeric_limits<double>::zero;
constexpr double numeric_limits<double>::minus_one;
const double numeric_limits<double>::ninf = -NPY_INFINITY;
const double numeric_limits<double>::nan = NPY_NAN;
467
/*
 * Constants for npy_cfloat. Two alternative definitions are selected at
 * compile time: when _MSC_VER is defined and __INTEL_COMPILER is not, the
 * values are brace-initialized from two floats; otherwise they are
 * initialized from a single float.
 */
#if defined(_MSC_VER) && !defined(__INTEL_COMPILER)
template<>
struct numeric_limits<npy_cfloat> {
static constexpr npy_cfloat one = {1.0f, 0.0f};
static constexpr npy_cfloat zero = {0.0f, 0.0f};
static constexpr npy_cfloat minus_one = {-1.0f, 0.0f};
static const npy_cfloat ninf;
static const npy_cfloat nan;
};
constexpr npy_cfloat numeric_limits<npy_cfloat>::one;
constexpr npy_cfloat numeric_limits<npy_cfloat>::zero;
constexpr npy_cfloat numeric_limits<npy_cfloat>::minus_one;
const npy_cfloat numeric_limits<npy_cfloat>::ninf = {-NPY_INFINITYF, 0.0f};
const npy_cfloat numeric_limits<npy_cfloat>::nan = {NPY_NANF, NPY_NANF};
#else
template<>
struct numeric_limits<npy_cfloat> {
static constexpr npy_cfloat one = 1.0f;
static constexpr npy_cfloat zero = 0.0f;
static constexpr npy_cfloat minus_one = -1.0f;
static const npy_cfloat ninf;
static const npy_cfloat nan;
};
constexpr npy_cfloat numeric_limits<npy_cfloat>::one;
constexpr npy_cfloat numeric_limits<npy_cfloat>::zero;
constexpr npy_cfloat numeric_limits<npy_cfloat>::minus_one;
/* NOTE(review): here nan is built from one scalar, unlike the branch above
   which sets both components explicitly -- confirm the second component. */
const npy_cfloat numeric_limits<npy_cfloat>::ninf = -NPY_INFINITYF;
const npy_cfloat numeric_limits<npy_cfloat>::nan = NPY_NANF;
#endif
497
/*
 * Constants for f2c_complex, written as {r, i} pairs. ninf is
 * {-NPY_INFINITYF, 0} and nan sets both fields to NPY_NANF.
 */
template<>
struct numeric_limits<f2c_complex> {
static constexpr f2c_complex one = {1.0f, 0.0f};
static constexpr f2c_complex zero = {0.0f, 0.0f};
static constexpr f2c_complex minus_one = {-1.0f, 0.0f};
static const f2c_complex ninf;
static const f2c_complex nan;
};
/* out-of-class definitions of the static members declared above */
constexpr f2c_complex numeric_limits<f2c_complex>::one;
constexpr f2c_complex numeric_limits<f2c_complex>::zero;
constexpr f2c_complex numeric_limits<f2c_complex>::minus_one;
const f2c_complex numeric_limits<f2c_complex>::ninf = {-NPY_INFINITYF, 0.0f};
const f2c_complex numeric_limits<f2c_complex>::nan = {NPY_NANF, NPY_NANF};
511
/*
 * Constants for npy_cdouble. Two alternative definitions are selected at
 * compile time: when _MSC_VER is defined and __INTEL_COMPILER is not, the
 * values are brace-initialized from two doubles; otherwise they are
 * initialized from a single double.
 */
#if defined(_MSC_VER) && !defined(__INTEL_COMPILER)
template<>
struct numeric_limits<npy_cdouble> {
static constexpr npy_cdouble one = {1.0, 0.0};
static constexpr npy_cdouble zero = {0.0, 0.0};
static constexpr npy_cdouble minus_one = {-1.0, 0.0};
static const npy_cdouble ninf;
static const npy_cdouble nan;
};
constexpr npy_cdouble numeric_limits<npy_cdouble>::one;
constexpr npy_cdouble numeric_limits<npy_cdouble>::zero;
constexpr npy_cdouble numeric_limits<npy_cdouble>::minus_one;
const npy_cdouble numeric_limits<npy_cdouble>::ninf = {-NPY_INFINITY, 0.0};
const npy_cdouble numeric_limits<npy_cdouble>::nan = {NPY_NAN, NPY_NAN};
#else
template<>
struct numeric_limits<npy_cdouble> {
static constexpr npy_cdouble one = 1.0;
static constexpr npy_cdouble zero = 0.0;
static constexpr npy_cdouble minus_one = -1.0;
static const npy_cdouble ninf;
static const npy_cdouble nan;
};
constexpr npy_cdouble numeric_limits<npy_cdouble>::one;
constexpr npy_cdouble numeric_limits<npy_cdouble>::zero;
constexpr npy_cdouble numeric_limits<npy_cdouble>::minus_one;
/* NOTE(review): here nan is built from one scalar, unlike the branch above
   which sets both components explicitly -- confirm the second component. */
const npy_cdouble numeric_limits<npy_cdouble>::ninf = -NPY_INFINITY;
const npy_cdouble numeric_limits<npy_cdouble>::nan = NPY_NAN;
#endif
541
/*
 * Constants for f2c_doublecomplex, written as {r, i} pairs. ninf is
 * {-NPY_INFINITY, 0} and nan sets both fields to NPY_NAN.
 */
template<>
struct numeric_limits<f2c_doublecomplex> {
static constexpr f2c_doublecomplex one = {1.0, 0.0};
static constexpr f2c_doublecomplex zero = {0.0, 0.0};
static constexpr f2c_doublecomplex minus_one = {-1.0, 0.0};
static const f2c_doublecomplex ninf;
static const f2c_doublecomplex nan;
};
/* out-of-class definitions of the static members declared above */
constexpr f2c_doublecomplex numeric_limits<f2c_doublecomplex>::one;
constexpr f2c_doublecomplex numeric_limits<f2c_doublecomplex>::zero;
constexpr f2c_doublecomplex numeric_limits<f2c_doublecomplex>::minus_one;
const f2c_doublecomplex numeric_limits<f2c_doublecomplex>::ninf = {-NPY_INFINITY, 0.0};
const f2c_doublecomplex numeric_limits<f2c_doublecomplex>::nan = {NPY_NAN, NPY_NAN};
555
556/*
557 *****************************************************************************
558 ** Structs used for data rearrangement **
559 *****************************************************************************
560 */
561
562
563/*
564 * this struct contains information about how to linearize a matrix in a local
565 * buffer so that it can be used by blas functions. All strides are specified
566 * in bytes and are converted to elements later in type specific functions.
567 *
568 * rows: number of rows in the matrix
569 * columns: number of columns in the matrix
 * row_strides: the number of bytes between consecutive rows.
571 * column_strides: the number of bytes between consecutive columns.
572 * output_lead_dim: BLAS/LAPACK-side leading dimension, in elements
573 */
/* Layout description consumed by the (de)linearize and fill helpers. */
typedef struct linearize_data_struct
{
 npy_intp rows;            /* number of rows */
 npy_intp columns;         /* number of columns */
 npy_intp row_strides;     /* bytes between consecutive rows */
 npy_intp column_strides;  /* bytes between consecutive columns */
 npy_intp output_lead_dim; /* elements between rows in the packed buffer */
} LINEARIZE_DATA_t;
582
583static inline void
584init_linearize_data_ex(LINEARIZE_DATA_t *lin_data,
585 npy_intp rows,
586 npy_intp columns,
587 npy_intp row_strides,
588 npy_intp column_strides,
589 npy_intp output_lead_dim)
590{
591 lin_data->rows = rows;
592 lin_data->columns = columns;
593 lin_data->row_strides = row_strides;
594 lin_data->column_strides = column_strides;
595 lin_data->output_lead_dim = output_lead_dim;
596}
597
598static inline void
599init_linearize_data(LINEARIZE_DATA_t *lin_data,
600 npy_intp rows,
601 npy_intp columns,
602 npy_intp row_strides,
603 npy_intp column_strides)
604{
605 init_linearize_data_ex(
606 lin_data, rows, columns, row_strides, column_strides, columns);
607}
608
609static inline void
610dump_ufunc_object(PyUFuncObject* ufunc)
611{
612 TRACE_TXT("\n\n%s '%s' (%d input(s), %d output(s), %d specialization(s).\n",
613 ufunc->core_enabled? "generalized ufunc" : "scalar ufunc",
614 ufunc->name, ufunc->nin, ufunc->nout, ufunc->ntypes);
615 if (ufunc->core_enabled) {
616 int arg;
617 int dim;
618 TRACE_TXT("\t%s (%d dimension(s) detected).\n",
619 ufunc->core_signature, ufunc->core_num_dim_ix);
620
621 for (arg = 0; arg < ufunc->nargs; arg++){
622 int * arg_dim_ix = ufunc->core_dim_ixs + ufunc->core_offsets[arg];
623 TRACE_TXT("\t\targ %d (%s) has %d dimension(s): (",
624 arg, arg < ufunc->nin? "INPUT" : "OUTPUT",
625 ufunc->core_num_dims[arg]);
626 for (dim = 0; dim < ufunc->core_num_dims[arg]; dim ++) {
627 TRACE_TXT(" %d", arg_dim_ix[dim]);
628 }
629 TRACE_TXT(" )\n");
630 }
631 }
632}
633
634static inline void
635dump_linearize_data(const char* name, const LINEARIZE_DATA_t* params)
636{
637 TRACE_TXT("\n\t%s rows: %zd columns: %zd"\
638 "\n\t\trow_strides: %td column_strides: %td"\
639 "\n", name, params->rows, params->columns,
640 params->row_strides, params->column_strides);
641}
642
643static inline void
644print(npy_float s)
645{
646 TRACE_TXT(" %8.4f", s);
647}
648static inline void
649print(npy_double d)
650{
651 TRACE_TXT(" %10.6f", d);
652}
653static inline void
654print(npy_cfloat c)
655{
656 float* c_parts = (float*)&c;
657 TRACE_TXT("(%8.4f, %8.4fj)", c_parts[0], c_parts[1]);
658}
659static inline void
660print(npy_cdouble z)
661{
662 double* z_parts = (double*)&z;
663 TRACE_TXT("(%8.4f, %8.4fj)", z_parts[0], z_parts[1]);
664}
665
666template<typename typ>
667static inline void
668dump_matrix(const char* name,
669 size_t rows, size_t columns,
670 const typ* ptr)
671{
672 size_t i, j;
673
674 TRACE_TXT("\n%s %p (%zd, %zd)\n", name, ptr, rows, columns);
675 for (i = 0; i < rows; i++)
676 {
677 TRACE_TXT("| ");
678 for (j = 0; j < columns; j++)
679 {
680 print(ptr[j*rows + i]);
681 TRACE_TXT(", ");
682 }
683 TRACE_TXT(" |\n");
684 }
685}
686
687
688/*
689 *****************************************************************************
690 ** Basics **
691 *****************************************************************************
692 */
693
694static inline fortran_int
695fortran_int_min(fortran_int x, fortran_int y) {
696 return x < y ? x : y;
697}
698
699static inline fortran_int
700fortran_int_max(fortran_int x, fortran_int y) {
701 return x > y ? x : y;
702}
703
/*
 * INIT_OUTER_LOOP_k: declare the locals used by the matching
 * BEGIN_OUTER_LOOP_k macro.
 *
 *   dN      loop count, read from *dimensions
 *   N_      loop counter (left uninitialized here)
 *   s0..    one step per argument, read from successive *steps
 *
 * The macros expect pointers named `dimensions` and `steps` in scope and
 * advance both as a side effect. Each variant builds on the previous one and
 * adds one more step variable.
 */
#define INIT_OUTER_LOOP_1 \
 npy_intp dN = *dimensions++;\
 npy_intp N_;\
 npy_intp s0 = *steps++;

#define INIT_OUTER_LOOP_2 \
 INIT_OUTER_LOOP_1\
 npy_intp s1 = *steps++;

#define INIT_OUTER_LOOP_3 \
 INIT_OUTER_LOOP_2\
 npy_intp s2 = *steps++;

#define INIT_OUTER_LOOP_4 \
 INIT_OUTER_LOOP_3\
 npy_intp s3 = *steps++;

#define INIT_OUTER_LOOP_5 \
 INIT_OUTER_LOOP_4\
 npy_intp s4 = *steps++;

#define INIT_OUTER_LOOP_6 \
 INIT_OUTER_LOOP_5\
 npy_intp s5 = *steps++;

#define INIT_OUTER_LOOP_7 \
 INIT_OUTER_LOOP_6\
 npy_intp s6 = *steps++;
732
/*
 * BEGIN_OUTER_LOOP_k: open a `for` loop that runs dN times and, after each
 * iteration, advances args[0] .. args[k-1] by s0 .. s(k-1).
 *
 * Expects `args` plus the N_, dN and s* variables to be in scope. The macro
 * leaves a brace open; close it with END_OUTER_LOOP.
 */
#define BEGIN_OUTER_LOOP_2 \
 for (N_ = 0;\
 N_ < dN;\
 N_++, args[0] += s0,\
 args[1] += s1) {

#define BEGIN_OUTER_LOOP_3 \
 for (N_ = 0;\
 N_ < dN;\
 N_++, args[0] += s0,\
 args[1] += s1,\
 args[2] += s2) {

#define BEGIN_OUTER_LOOP_4 \
 for (N_ = 0;\
 N_ < dN;\
 N_++, args[0] += s0,\
 args[1] += s1,\
 args[2] += s2,\
 args[3] += s3) {

#define BEGIN_OUTER_LOOP_5 \
 for (N_ = 0;\
 N_ < dN;\
 N_++, args[0] += s0,\
 args[1] += s1,\
 args[2] += s2,\
 args[3] += s3,\
 args[4] += s4) {

#define BEGIN_OUTER_LOOP_6 \
 for (N_ = 0;\
 N_ < dN;\
 N_++, args[0] += s0,\
 args[1] += s1,\
 args[2] += s2,\
 args[3] += s3,\
 args[4] += s4,\
 args[5] += s5) {

#define BEGIN_OUTER_LOOP_7 \
 for (N_ = 0;\
 N_ < dN;\
 N_++, args[0] += s0,\
 args[1] += s1,\
 args[2] += s2,\
 args[3] += s3,\
 args[4] += s4,\
 args[5] += s5,\
 args[6] += s6) {

/* Closes the brace opened by any BEGIN_OUTER_LOOP_k. */
#define END_OUTER_LOOP }
785
786static inline void
787update_pointers(npy_uint8** bases, ptrdiff_t* offsets, size_t count)
788{
789 size_t i;
790 for (i = 0; i < count; ++i) {
791 bases[i] += offsets[i];
792 }
793}
794
795
/* disable -Wmaybe-uninitialized as there is some code that generates false
   positives with this warning
*/
799#pragma GCC diagnostic push
800#pragma GCC diagnostic ignored "-Wmaybe-uninitialized"
801
802/*
803 *****************************************************************************
804 ** DISPATCHER FUNCS **
805 *****************************************************************************
806 */
807static fortran_int copy(fortran_int *n,
808 float *sx, fortran_int *incx,
809 float *sy, fortran_int *incy) { return FNAME(scopy)(n, sx, incx,
810 sy, incy);
811}
812static fortran_int copy(fortran_int *n,
813 double *sx, fortran_int *incx,
814 double *sy, fortran_int *incy) { return FNAME(dcopy)(n, sx, incx,
815 sy, incy);
816}
817static fortran_int copy(fortran_int *n,
818 f2c_complex *sx, fortran_int *incx,
819 f2c_complex *sy, fortran_int *incy) { return FNAME(ccopy)(n, sx, incx,
820 sy, incy);
821}
822static fortran_int copy(fortran_int *n,
823 f2c_doublecomplex *sx, fortran_int *incx,
824 f2c_doublecomplex *sy, fortran_int *incy) { return FNAME(zcopy)(n, sx, incx,
825 sy, incy);
826}
827
828static fortran_int getrf(fortran_int *m, fortran_int *n, float a[], fortran_int
829*lda, fortran_int ipiv[], fortran_int *info) {
830 return LAPACK(sgetrf)(m, n, a, lda, ipiv, info);
831}
832static fortran_int getrf(fortran_int *m, fortran_int *n, double a[], fortran_int
833*lda, fortran_int ipiv[], fortran_int *info) {
834 return LAPACK(dgetrf)(m, n, a, lda, ipiv, info);
835}
836static fortran_int getrf(fortran_int *m, fortran_int *n, f2c_complex a[], fortran_int
837*lda, fortran_int ipiv[], fortran_int *info) {
838 return LAPACK(cgetrf)(m, n, a, lda, ipiv, info);
839}
840static fortran_int getrf(fortran_int *m, fortran_int *n, f2c_doublecomplex a[], fortran_int
841*lda, fortran_int ipiv[], fortran_int *info) {
842 return LAPACK(zgetrf)(m, n, a, lda, ipiv, info);
843}
844
845/*
846 *****************************************************************************
847 ** HELPER FUNCS **
848 *****************************************************************************
849 */
850template<typename T>
851struct fortran_type {
852using type = T;
853};
854
855template<> struct fortran_type<npy_cfloat> { using type = f2c_complex;};
856template<> struct fortran_type<npy_cdouble> { using type = f2c_doublecomplex;};
857template<typename T>
858using fortran_type_t = typename fortran_type<T>::type;
859
860template<typename T>
861struct basetype {
862using type = T;
863};
864template<> struct basetype<npy_cfloat> { using type = npy_float;};
865template<> struct basetype<npy_cdouble> { using type = npy_double;};
866template<> struct basetype<f2c_complex> { using type = fortran_real;};
867template<> struct basetype<f2c_doublecomplex> { using type = fortran_doublereal;};
868template<typename T>
869using basetype_t = typename basetype<T>::type;
870
871struct scalar_trait {};
872struct complex_trait {};
873template<typename typ>
874using dispatch_scalar = typename std::conditional<sizeof(basetype_t<typ>) == sizeof(typ), scalar_trait, complex_trait>::type;
875
876
877 /* rearranging of 2D matrices using blas */
878
879template<typename typ>
880static inline void *
881linearize_matrix(typ *dst,
882 typ *src,
883 const LINEARIZE_DATA_t* data)
884{
885 using ftyp = fortran_type_t<typ>;
886 if (dst) {
887 int i, j;
888 typ* rv = dst;
889 fortran_int columns = (fortran_int)data->columns;
890 fortran_int column_strides =
891 (fortran_int)(data->column_strides/sizeof(typ));
892 fortran_int one = 1;
893 for (i = 0; i < data->rows; i++) {
894 if (column_strides > 0) {
895 copy(&columns,
896 (ftyp*)src, &column_strides,
897 (ftyp*)dst, &one);
898 }
899 else if (column_strides < 0) {
900 copy(&columns,
901 ((ftyp*)src + (columns-1)*column_strides),
902 &column_strides,
903 (ftyp*)dst, &one);
904 }
905 else {
906 /*
907 * Zero stride has undefined behavior in some BLAS
908 * implementations (e.g. OSX Accelerate), so do it
909 * manually
910 */
911 for (j = 0; j < columns; ++j) {
912 memcpy(dst + j, src, sizeof(typ));
913 }
914 }
915 src += data->row_strides/sizeof(typ);
916 dst += data->output_lead_dim;
917 }
918 return rv;
919 } else {
920 return src;
921 }
922}
923
924template<typename typ>
925static inline void *
926delinearize_matrix(typ *dst,
927 typ *src,
928 const LINEARIZE_DATA_t* data)
929{
930using ftyp = fortran_type_t<typ>;
931
932 if (src) {
933 int i;
934 typ *rv = src;
935 fortran_int columns = (fortran_int)data->columns;
936 fortran_int column_strides =
937 (fortran_int)(data->column_strides/sizeof(typ));
938 fortran_int one = 1;
939 for (i = 0; i < data->rows; i++) {
940 if (column_strides > 0) {
941 copy(&columns,
942 (ftyp*)src, &one,
943 (ftyp*)dst, &column_strides);
944 }
945 else if (column_strides < 0) {
946 copy(&columns,
947 (ftyp*)src, &one,
948 ((ftyp*)dst + (columns-1)*column_strides),
949 &column_strides);
950 }
951 else {
952 /*
953 * Zero stride has undefined behavior in some BLAS
954 * implementations (e.g. OSX Accelerate), so do it
955 * manually
956 */
957 if (columns > 0) {
958 memcpy(dst,
959 src + (columns-1),
960 sizeof(typ));
961 }
962 }
963 src += data->output_lead_dim;
964 dst += data->row_strides/sizeof(typ);
965 }
966
967 return rv;
968 } else {
969 return src;
970 }
971}
972
973template<typename typ>
974static inline void
975nan_matrix(typ *dst, const LINEARIZE_DATA_t* data)
976{
977 int i, j;
978 for (i = 0; i < data->rows; i++) {
979 typ *cp = dst;
980 ptrdiff_t cs = data->column_strides/sizeof(typ);
981 for (j = 0; j < data->columns; ++j) {
982 *cp = numeric_limits<typ>::nan;
983 cp += cs;
984 }
985 dst += data->row_strides/sizeof(typ);
986 }
987}
988
989template<typename typ>
990static inline void
991zero_matrix(typ *dst, const LINEARIZE_DATA_t* data)
992{
993 int i, j;
994 for (i = 0; i < data->rows; i++) {
995 typ *cp = dst;
996 ptrdiff_t cs = data->column_strides/sizeof(typ);
997 for (j = 0; j < data->columns; ++j) {
998 *cp = numeric_limits<typ>::zero;
999 cp += cs;
1000 }
1001 dst += data->row_strides/sizeof(typ);
1002 }
1003}
1004
1005 /* identity square matrix generation */
1006template<typename typ>
1007static inline void
1008identity_matrix(typ *matrix, size_t n)
1009{
1010 size_t i;
1011 /* in IEEE floating point, zeroes are represented as bitwise 0 */
1012 memset((void *)matrix, 0, n*n*sizeof(typ));
1013
1014 for (i = 0; i < n; ++i)
1015 {
1016 *matrix = numeric_limits<typ>::one;
1017 matrix += n+1;
1018 }
1019}
1020
/* lower/upper triangular matrix, built in place with plain loops (no BLAS call) */
1022
1023template<typename typ>
1024static inline void
1025triu_matrix(typ *matrix, size_t n)
1026{
1027 size_t i, j;
1028 matrix += n;
1029 for (i = 1; i < n; ++i) {
1030 for (j = 0; j < i; ++j) {
1031 matrix[j] = numeric_limits<typ>::zero;
1032 }
1033 matrix += n;
1034 }
1035}
1036
1037
1038/* -------------------------------------------------------------------------- */
1039 /* Determinants */
1040
1041static npy_float npylog(npy_float f) { return npy_logf(f);}
1042static npy_double npylog(npy_double d) { return npy_log(d);}
1043static npy_float npyexp(npy_float f) { return npy_expf(f);}
1044static npy_double npyexp(npy_double d) { return npy_exp(d);}
1045
1046template<typename typ>
1047static inline void
1048slogdet_from_factored_diagonal(typ* src,
1049 fortran_int m,
1050 typ *sign,
1051 typ *logdet)
1052{
1053 typ acc_sign = *sign;
1054 typ acc_logdet = numeric_limits<typ>::zero;
1055 int i;
1056 for (i = 0; i < m; i++) {
1057 typ abs_element = *src;
1058 if (abs_element < numeric_limits<typ>::zero) {
1059 acc_sign = -acc_sign;
1060 abs_element = -abs_element;
1061 }
1062
1063 acc_logdet += npylog(abs_element);
1064 src += m+1;
1065 }
1066
1067 *sign = acc_sign;
1068 *logdet = acc_logdet;
1069}
1070
/* Real-valued version: rebuild the determinant as sign * exp(logdet). */
template<typename typ>
static inline typ
det_from_slogdet(typ sign, typ logdet)
{
    return sign * npyexp(logdet);
}
1078
1079
1080npy_float npyabs(npy_cfloat z) { return npy_cabsf(z);}
1081npy_double npyabs(npy_cdouble z) { return npy_cabs(z);}
1082
1083inline float RE(npy_cfloat *c) { return npy_crealf(*c); }
1084inline double RE(npy_cdouble *c) { return npy_creal(*c); }
1085#if NPY_SIZEOF_COMPLEX_LONGDOUBLE != NPY_SIZEOF_COMPLEX_DOUBLE
1086inline longdouble_t RE(npy_clongdouble *c) { return npy_creall(*c); }
1087#endif
1088inline float IM(npy_cfloat *c) { return npy_cimagf(*c); }
1089inline double IM(npy_cdouble *c) { return npy_cimag(*c); }
1090#if NPY_SIZEOF_COMPLEX_LONGDOUBLE != NPY_SIZEOF_COMPLEX_DOUBLE
1091inline longdouble_t IM(npy_clongdouble *c) { return npy_cimagl(*c); }
1092#endif
1093inline void SETRE(npy_cfloat *c, float real) { npy_csetrealf(c, real); }
1094inline void SETRE(npy_cdouble *c, double real) { npy_csetreal(c, real); }
1095#if NPY_SIZEOF_COMPLEX_LONGDOUBLE != NPY_SIZEOF_COMPLEX_DOUBLE
1096inline void SETRE(npy_clongdouble *c, double real) { npy_csetreall(c, real); }
1097#endif
1098inline void SETIM(npy_cfloat *c, float real) { npy_csetimagf(c, real); }
1099inline void SETIM(npy_cdouble *c, double real) { npy_csetimag(c, real); }
1100#if NPY_SIZEOF_COMPLEX_LONGDOUBLE != NPY_SIZEOF_COMPLEX_DOUBLE
1101inline void SETIM(npy_clongdouble *c, double real) { npy_csetimagl(c, real); }
1102#endif
1103
/*
 * Complex product op1 * op2, spelled out through the RE/IM/SETRE/SETIM
 * accessors: (a + bi)(c + di) = (ac - bd) + (ad + bc)i.
 */
template<typename typ>
static inline typ
mult(typ op1, typ op2)
{
    typ product;

    SETRE(&product, RE(&op1)*RE(&op2) - IM(&op1)*IM(&op2));
    SETIM(&product, RE(&op1)*IM(&op2) + IM(&op1)*RE(&op2));

    return product;
}
1115
1116
1117template<typename typ, typename basetyp>
1118static inline void
1119slogdet_from_factored_diagonal(typ* src,
1120 fortran_int m,
1121 typ *sign,
1122 basetyp *logdet)
1123{
1124 int i;
1125 typ sign_acc = *sign;
1126 basetyp logdet_acc = numeric_limits<basetyp>::zero;
1127
1128 for (i = 0; i < m; i++)
1129 {
1130 basetyp abs_element = npyabs(*src);
1131 typ sign_element;
1132 SETRE(&sign_element, RE(src) / abs_element);
1133 SETIM(&sign_element, IM(src) / abs_element);
1134
1135 sign_acc = mult(sign_acc, sign_element);
1136 logdet_acc += npylog(abs_element);
1137 src += m + 1;
1138 }
1139
1140 *sign = sign_acc;
1141 *logdet = logdet_acc;
1142}
1143
1144template<typename typ, typename basetyp>
1145static inline typ
1146det_from_slogdet(typ sign, basetyp logdet)
1147{
1148 typ tmp;
1149 SETRE(&tmp, npyexp(logdet));
1150 SETIM(&tmp, numeric_limits<basetyp>::zero);
1151 return mult(sign, tmp);
1152}
1153
1154
/* As in the linalg package, the determinant is computed via LU factorization
 * using LAPACK.
 * slogdet computes the sign and log(|determinant|) as two separate outputs.
 * det computes sign * exp(logdet) from those two values.
 */
1160template<typename typ, typename basetyp>
1161static inline void
1162slogdet_single_element(fortran_int m,
1163 typ* src,
1164 fortran_int* pivots,
1165 typ *sign,
1166 basetyp *logdet)
1167{
1168using ftyp = fortran_type_t<typ>;
1169 fortran_int info = 0;
1170 fortran_int lda = fortran_int_max(m, 1);
1171 int i;
1172 /* note: done in place */
1173 getrf(&m, &m, (ftyp*)src, &lda, pivots, &info);
1174
1175 if (info == 0) {
1176 int change_sign = 0;
1177 /* note: fortran uses 1 based indexing */
1178 for (i = 0; i < m; i++)
1179 {
1180 change_sign += (pivots[i] != (i+1));
1181 }
1182
1183 *sign = (change_sign % 2)?numeric_limits<typ>::minus_one:numeric_limits<typ>::one;
1184 slogdet_from_factored_diagonal(src, m, sign, logdet);
1185 } else {
1186 /*
1187 if getrf fails, use 0 as sign and -inf as logdet
1188 */
1189 *sign = numeric_limits<typ>::zero;
1190 *logdet = numeric_limits<basetyp>::ninf;
1191 }
1192}
1193
/*
 * gufunc inner loop for slogdet.
 *
 * For each matrix of the outer loop, the input (args[0]) is copied into a
 * scratch buffer and passed to slogdet_single_element, which writes the sign
 * to args[1] and the log-determinant to args[2].
 *
 * One scratch allocation holds the matrix copy followed by the pivot array.
 * If that allocation fails a Python MemoryError is set and no output is
 * written.
 *
 * NOTE(review): INIT_OUTER_LOOP_3 / BEGIN_OUTER_LOOP_3 / END_OUTER_LOOP are
 * macros defined elsewhere in this file; presumably they declare the outer
 * loop state and advance `dimensions`, `steps` and `args` - confirm there.
 */
template<typename typ, typename basetyp>
static void
slogdet(char **args,
        npy_intp const *dimensions,
        npy_intp const *steps,
        void *NPY_UNUSED(func))
{
    fortran_int m;
    char *tmp_buff = NULL;
    size_t matrix_size;
    size_t pivot_size;
    size_t safe_m;
    /* notes:
     *   matrix will need to be copied always, as factorization in lapack is
     *          made inplace
     *   matrix will need to be in column-major order, as expected by lapack
     *          code (fortran)
     *   always a square matrix
     *   need to allocate memory for both, matrix_buffer and pivot buffer
     */
    INIT_OUTER_LOOP_3
    m = (fortran_int) dimensions[0];
    /* avoid empty malloc (buffers likely unused) and ensure m is `size_t` */
    safe_m = m != 0 ? m : 1;
    matrix_size = safe_m * safe_m * sizeof(typ);
    pivot_size = safe_m * sizeof(fortran_int);
    /* single block: matrix copy first, pivot indices right after it */
    tmp_buff = (char *)malloc(matrix_size + pivot_size);

    if (tmp_buff) {
        LINEARIZE_DATA_t lin_data;
        /* swapped steps to get matrix in FORTRAN order */
        init_linearize_data(&lin_data, m, m, steps[1], steps[0]);
        BEGIN_OUTER_LOOP_3
            linearize_matrix((typ*)tmp_buff, (typ*)args[0], &lin_data);
            slogdet_single_element(m,
                                   (typ*)tmp_buff,
                                   (fortran_int*)(tmp_buff+matrix_size),
                                   (typ*)args[1],
                                   (basetyp*)args[2]);
        END_OUTER_LOOP

        free(tmp_buff);
    }
    else {
        /* TODO: Requires use of new ufunc API to indicate error return */
        NPY_ALLOW_C_API_DEF
        NPY_ALLOW_C_API;
        PyErr_NoMemory();
        NPY_DISABLE_C_API;
    }
}
1245
/*
 * gufunc inner loop for det.
 *
 * For each matrix of the outer loop, the input (args[0]) is copied into a
 * scratch buffer, slogdet_single_element produces sign and log-determinant
 * in local variables, and det_from_slogdet combines them into the value
 * stored at args[1].
 *
 * One scratch allocation holds the matrix copy followed by the pivot array.
 * If that allocation fails a Python MemoryError is set and no output is
 * written.
 *
 * NOTE(review): INIT_OUTER_LOOP_2 / BEGIN_OUTER_LOOP_2 / END_OUTER_LOOP are
 * macros defined elsewhere in this file; presumably they declare the outer
 * loop state and advance `dimensions`, `steps` and `args` - confirm there.
 */
template<typename typ, typename basetyp>
static void
det(char **args,
    npy_intp const *dimensions,
    npy_intp const *steps,
    void *NPY_UNUSED(func))
{
    fortran_int m;
    char *tmp_buff;
    size_t matrix_size;
    size_t pivot_size;
    size_t safe_m;
    /* notes:
     *   matrix will need to be copied always, as factorization in lapack is
     *          made inplace
     *   matrix will need to be in column-major order, as expected by lapack
     *          code (fortran)
     *   always a square matrix
     *   need to allocate memory for both, matrix_buffer and pivot buffer
     */
    INIT_OUTER_LOOP_2
    m = (fortran_int) dimensions[0];
    /* avoid empty malloc (buffers likely unused) and ensure m is `size_t` */
    safe_m = m != 0 ? m : 1;
    matrix_size = safe_m * safe_m * sizeof(typ);
    pivot_size = safe_m * sizeof(fortran_int);
    /* single block: matrix copy first, pivot indices right after it */
    tmp_buff = (char *)malloc(matrix_size + pivot_size);

    if (tmp_buff) {
        LINEARIZE_DATA_t lin_data;
        typ sign;
        basetyp logdet;
        /* swapped steps to get matrix in FORTRAN order */
        init_linearize_data(&lin_data, m, m, steps[1], steps[0]);

        BEGIN_OUTER_LOOP_2
            linearize_matrix((typ*)tmp_buff, (typ*)args[0], &lin_data);
            slogdet_single_element(m,
                                   (typ*)tmp_buff,
                                   (fortran_int*)(tmp_buff + matrix_size),
                                   &sign,
                                   &logdet);
            *(typ *)args[1] = det_from_slogdet(sign, logdet);
        END_OUTER_LOOP

        free(tmp_buff);
    }
    else {
        /* TODO: Requires use of new ufunc API to indicate error return */
        NPY_ALLOW_C_API_DEF
        NPY_ALLOW_C_API;
        PyErr_NoMemory();
        NPY_DISABLE_C_API;
    }
}
1301
1302
1303/* -------------------------------------------------------------------------- */
1304 /* Eigh family */
1305
/*
 * Argument bundle handed to the call_evd() overloads (LAPACK _syevd for real
 * types, _heevd for complex types).  init_evd() fills it in and allocates the
 * buffers; release_evd() frees A and WORK, the bases of the two allocations.
 */
template<typename typ>
struct EIGH_PARAMS_t {
    typ *A;     /* matrix */
    basetype_t<typ> *W; /* eigenvalue vector */
    typ *WORK; /* main work buffer */
    basetype_t<typ> *RWORK; /* secondary work buffer (for complex versions) */
    fortran_int *IWORK; /* integer work buffer */
    fortran_int N;      /* matrix order passed to LAPACK */
    fortran_int LWORK;  /* element count of WORK (-1 during the size query) */
    fortran_int LRWORK; /* element count of RWORK (0 for real types) */
    fortran_int LIWORK; /* element count of IWORK (-1 during the size query) */
    char JOBZ;          /* 'N' or 'V'; the wrapper copies eigenvectors out for 'V' */
    char UPLO;          /* 'L' or 'U', forwarded unchanged to LAPACK */
    fortran_int LDA;    /* leading dimension of A, max(N, 1) */
} ;
1321
1322static inline fortran_int
1323call_evd(EIGH_PARAMS_t<npy_float> *params)
1324{
1325 fortran_int rv;
1326 LAPACK(ssyevd)(¶ms->JOBZ, ¶ms->UPLO, ¶ms->N,
1327 params->A, ¶ms->LDA, params->W,
1328 params->WORK, ¶ms->LWORK,
1329 params->IWORK, ¶ms->LIWORK,
1330 &rv);
1331 return rv;
1332}
1333static inline fortran_int
1334call_evd(EIGH_PARAMS_t<npy_double> *params)
1335{
1336 fortran_int rv;
1337 LAPACK(dsyevd)(¶ms->JOBZ, ¶ms->UPLO, ¶ms->N,
1338 params->A, ¶ms->LDA, params->W,
1339 params->WORK, ¶ms->LWORK,
1340 params->IWORK, ¶ms->LIWORK,
1341 &rv);
1342 return rv;
1343}
1344
1345
/*
 * Initialize the parameters to use for the LAPACK function _syevd.
 * Handles buffer allocation.
 */
1350template<typename typ>
1351static inline int
1352init_evd(EIGH_PARAMS_t<typ>* params, char JOBZ, char UPLO,
1353 fortran_int N, scalar_trait)
1354{
1355 npy_uint8 *mem_buff = NULL;
1356 npy_uint8 *mem_buff2 = NULL;
1357 fortran_int lwork;
1358 fortran_int liwork;
1359 npy_uint8 *a, *w, *work, *iwork;
1360 size_t safe_N = N;
1361 size_t alloc_size = safe_N * (safe_N + 1) * sizeof(typ);
1362 fortran_int lda = fortran_int_max(N, 1);
1363
1364 mem_buff = (npy_uint8 *)malloc(alloc_size);
1365
1366 if (!mem_buff) {
1367 goto error;
1368 }
1369 a = mem_buff;
1370 w = mem_buff + safe_N * safe_N * sizeof(typ);
1371
1372 params->A = (typ*)a;
1373 params->W = (typ*)w;
1374 params->RWORK = NULL; /* unused */
1375 params->N = N;
1376 params->LRWORK = 0; /* unused */
1377 params->JOBZ = JOBZ;
1378 params->UPLO = UPLO;
1379 params->LDA = lda;
1380
1381 /* Work size query */
1382 {
1383 typ query_work_size;
1384 fortran_int query_iwork_size;
1385
1386 params->LWORK = -1;
1387 params->LIWORK = -1;
1388 params->WORK = &query_work_size;
1389 params->IWORK = &query_iwork_size;
1390
1391 if (call_evd(params) != 0) {
1392 goto error;
1393 }
1394
1395 lwork = (fortran_int)query_work_size;
1396 liwork = query_iwork_size;
1397 }
1398
1399 mem_buff2 = (npy_uint8 *)malloc(lwork*sizeof(typ) + liwork*sizeof(fortran_int));
1400 if (!mem_buff2) {
1401 goto error;
1402 }
1403
1404 work = mem_buff2;
1405 iwork = mem_buff2 + lwork*sizeof(typ);
1406
1407 params->LWORK = lwork;
1408 params->WORK = (typ*)work;
1409 params->LIWORK = liwork;
1410 params->IWORK = (fortran_int*)iwork;
1411
1412 return 1;
1413
1414 error:
1415 /* something failed */
1416 memset(params, 0, sizeof(*params));
1417 free(mem_buff2);
1418 free(mem_buff);
1419
1420 return 0;
1421}
1422
1423
1424static inline fortran_int
1425call_evd(EIGH_PARAMS_t<npy_cfloat> *params)
1426{
1427 fortran_int rv;
1428 LAPACK(cheevd)(¶ms->JOBZ, ¶ms->UPLO, ¶ms->N,
1429 (fortran_type_t<npy_cfloat>*)params->A, ¶ms->LDA, params->W,
1430 (fortran_type_t<npy_cfloat>*)params->WORK, ¶ms->LWORK,
1431 params->RWORK, ¶ms->LRWORK,
1432 params->IWORK, ¶ms->LIWORK,
1433 &rv);
1434 return rv;
1435}
1436
1437static inline fortran_int
1438call_evd(EIGH_PARAMS_t<npy_cdouble> *params)
1439{
1440 fortran_int rv;
1441 LAPACK(zheevd)(¶ms->JOBZ, ¶ms->UPLO, ¶ms->N,
1442 (fortran_type_t<npy_cdouble>*)params->A, ¶ms->LDA, params->W,
1443 (fortran_type_t<npy_cdouble>*)params->WORK, ¶ms->LWORK,
1444 params->RWORK, ¶ms->LRWORK,
1445 params->IWORK, ¶ms->LIWORK,
1446 &rv);
1447 return rv;
1448}
1449
1450template<typename typ>
1451static inline int
1452init_evd(EIGH_PARAMS_t<typ> *params,
1453 char JOBZ,
1454 char UPLO,
1455 fortran_int N, complex_trait)
1456{
1457 using basetyp = basetype_t<typ>;
1458using ftyp = fortran_type_t<typ>;
1459using fbasetyp = fortran_type_t<basetyp>;
1460 npy_uint8 *mem_buff = NULL;
1461 npy_uint8 *mem_buff2 = NULL;
1462 fortran_int lwork;
1463 fortran_int lrwork;
1464 fortran_int liwork;
1465 npy_uint8 *a, *w, *work, *rwork, *iwork;
1466 size_t safe_N = N;
1467 fortran_int lda = fortran_int_max(N, 1);
1468
1469 mem_buff = (npy_uint8 *)malloc(safe_N * safe_N * sizeof(typ) +
1470 safe_N * sizeof(basetyp));
1471 if (!mem_buff) {
1472 goto error;
1473 }
1474 a = mem_buff;
1475 w = mem_buff + safe_N * safe_N * sizeof(typ);
1476
1477 params->A = (typ*)a;
1478 params->W = (basetyp*)w;
1479 params->N = N;
1480 params->JOBZ = JOBZ;
1481 params->UPLO = UPLO;
1482 params->LDA = lda;
1483
1484 /* Work size query */
1485 {
1486 ftyp query_work_size;
1487 fbasetyp query_rwork_size;
1488 fortran_int query_iwork_size;
1489
1490 params->LWORK = -1;
1491 params->LRWORK = -1;
1492 params->LIWORK = -1;
1493 params->WORK = (typ*)&query_work_size;
1494 params->RWORK = (basetyp*)&query_rwork_size;
1495 params->IWORK = &query_iwork_size;
1496
1497 if (call_evd(params) != 0) {
1498 goto error;
1499 }
1500
1501 lwork = (fortran_int)*(fbasetyp*)&query_work_size;
1502 lrwork = (fortran_int)query_rwork_size;
1503 liwork = query_iwork_size;
1504 }
1505
1506 mem_buff2 = (npy_uint8 *)malloc(lwork*sizeof(typ) +
1507 lrwork*sizeof(basetyp) +
1508 liwork*sizeof(fortran_int));
1509 if (!mem_buff2) {
1510 goto error;
1511 }
1512
1513 work = mem_buff2;
1514 rwork = work + lwork*sizeof(typ);
1515 iwork = rwork + lrwork*sizeof(basetyp);
1516
1517 params->WORK = (typ*)work;
1518 params->RWORK = (basetyp*)rwork;
1519 params->IWORK = (fortran_int*)iwork;
1520 params->LWORK = lwork;
1521 params->LRWORK = lrwork;
1522 params->LIWORK = liwork;
1523
1524 return 1;
1525
1526 /* something failed */
1527error:
1528 memset(params, 0, sizeof(*params));
1529 free(mem_buff2);
1530 free(mem_buff);
1531
1532 return 0;
1533}
1534
1535/*
1536 * (M, M)->(M,)(M, M)
1537 * dimensions[1] -> M
1538 * args[0] -> A[in]
1539 * args[1] -> W
1540 * args[2] -> A[out]
1541 */
1542
1543template<typename typ>
1544static inline void
1545release_evd(EIGH_PARAMS_t<typ> *params)
1546{
1547 /* allocated memory in A and WORK */
1548 free(params->A);
1549 free(params->WORK);
1550 memset(params, 0, sizeof(*params));
1551}
1552
1553
1554template<typename typ>
1555static inline void
1556eigh_wrapper(char JOBZ,
1557 char UPLO,
1558 char**args,
1559 npy_intp const *dimensions,
1560 npy_intp const *steps)
1561{
1562 using basetyp = basetype_t<typ>;
1563 ptrdiff_t outer_steps[3];
1564 size_t iter;
1565 size_t outer_dim = *dimensions++;
1566 size_t op_count = (JOBZ=='N')?2:3;
1567 EIGH_PARAMS_t<typ> eigh_params;
1568 int error_occurred = get_fp_invalid_and_clear();
1569
1570 for (iter = 0; iter < op_count; ++iter) {
1571 outer_steps[iter] = (ptrdiff_t) steps[iter];
1572 }
1573 steps += op_count;
1574
1575 if (init_evd(&eigh_params,
1576 JOBZ,
1577 UPLO,
1578 (fortran_int)dimensions[0], dispatch_scalar<typ>())) {
1579 LINEARIZE_DATA_t matrix_in_ld;
1580 LINEARIZE_DATA_t eigenvectors_out_ld;
1581 LINEARIZE_DATA_t eigenvalues_out_ld;
1582
1583 init_linearize_data(&matrix_in_ld,
1584 eigh_params.N, eigh_params.N,
1585 steps[1], steps[0]);
1586 init_linearize_data(&eigenvalues_out_ld,
1587 1, eigh_params.N,
1588 0, steps[2]);
1589 if ('V' == eigh_params.JOBZ) {
1590 init_linearize_data(&eigenvectors_out_ld,
1591 eigh_params.N, eigh_params.N,
1592 steps[4], steps[3]);
1593 }
1594
1595 for (iter = 0; iter < outer_dim; ++iter) {
1596 int not_ok;
1597 /* copy the matrix in */
1598 linearize_matrix((typ*)eigh_params.A, (typ*)args[0], &matrix_in_ld);
1599 not_ok = call_evd(&eigh_params);
1600 if (!not_ok) {
1601 /* lapack ok, copy result out */
1602 delinearize_matrix((basetyp*)args[1],
1603 (basetyp*)eigh_params.W,
1604 &eigenvalues_out_ld);
1605
1606 if ('V' == eigh_params.JOBZ) {
1607 delinearize_matrix((typ*)args[2],
1608 (typ*)eigh_params.A,
1609 &eigenvectors_out_ld);
1610 }
1611 } else {
1612 /* lapack fail, set result to nan */
1613 error_occurred = 1;
1614 nan_matrix((basetyp*)args[1], &eigenvalues_out_ld);
1615 if ('V' == eigh_params.JOBZ) {
1616 nan_matrix((typ*)args[2], &eigenvectors_out_ld);
1617 }
1618 }
1619 update_pointers((npy_uint8**)args, outer_steps, op_count);
1620 }
1621
1622 release_evd(&eigh_params);
1623 }
1624
1625 set_fp_invalid_or_clear(error_occurred);
1626}
1627
1628
/*
 * gufunc loop entry point: runs eigh_wrapper with JOBZ='V' (eigenvalues and
 * eigenvectors are written) and UPLO='L'.  The `func` argument is unused.
 */
template<typename typ>
static void
eighlo(char **args,
       npy_intp const *dimensions,
       npy_intp const *steps,
       void *NPY_UNUSED(func))
{
    eigh_wrapper<typ>('V', 'L', args, dimensions, steps);
}
1638
/*
 * gufunc loop entry point: runs eigh_wrapper with JOBZ='V' (eigenvalues and
 * eigenvectors are written) and UPLO='U'.  The `func` argument is unused.
 */
template<typename typ>
static void
eighup(char **args,
       npy_intp const *dimensions,
       npy_intp const *steps,
       void* NPY_UNUSED(func))
{
    eigh_wrapper<typ>('V', 'U', args, dimensions, steps);
}
1648
/*
 * gufunc loop entry point: runs eigh_wrapper with JOBZ='N' (only eigenvalues
 * are written) and UPLO='L'.  The `func` argument is unused.
 */
template<typename typ>
static void
eigvalshlo(char **args,
           npy_intp const *dimensions,
           npy_intp const *steps,
           void* NPY_UNUSED(func))
{
    eigh_wrapper<typ>('N', 'L', args, dimensions, steps);
}
1658
/*
 * gufunc loop entry point: runs eigh_wrapper with JOBZ='N' (only eigenvalues
 * are written) and UPLO='U'.  The `func` argument is unused.
 */
template<typename typ>
static void
eigvalshup(char **args,
           npy_intp const *dimensions,
           npy_intp const *steps,
           void* NPY_UNUSED(func))
{
    eigh_wrapper<typ>('N', 'U', args, dimensions, steps);
}
1668
1669/* -------------------------------------------------------------------------- */
1670 /* Solve family (includes inv) */
1671
/*
 * Argument bundle handed to the call_gesv() overloads.  init_gesv() carves
 * A, B and IPIV out of a single allocation whose base address is A, which is
 * what release_gesv() frees.
 */
template<typename typ>
struct GESV_PARAMS_t
{
    typ *A; /* A is (N, N) of base type */
    typ *B; /* B is (N, NRHS) of base type */
    fortran_int * IPIV; /* IPIV is (N) */

    fortran_int N;    /* order of A */
    fortran_int NRHS; /* number of columns of B */
    fortran_int LDA;  /* leading dimension of A, max(N, 1) */
    fortran_int LDB;  /* leading dimension of B, max(N, 1) */
};
1684
1685static inline fortran_int
1686call_gesv(GESV_PARAMS_t<fortran_real> *params)
1687{
1688 fortran_int rv;
1689 LAPACK(sgesv)(¶ms->N, ¶ms->NRHS,
1690 params->A, ¶ms->LDA,
1691 params->IPIV,
1692 params->B, ¶ms->LDB,
1693 &rv);
1694 return rv;
1695}
1696
1697static inline fortran_int
1698call_gesv(GESV_PARAMS_t<fortran_doublereal> *params)
1699{
1700 fortran_int rv;
1701 LAPACK(dgesv)(¶ms->N, ¶ms->NRHS,
1702 params->A, ¶ms->LDA,
1703 params->IPIV,
1704 params->B, ¶ms->LDB,
1705 &rv);
1706 return rv;
1707}
1708
1709static inline fortran_int
1710call_gesv(GESV_PARAMS_t<fortran_complex> *params)
1711{
1712 fortran_int rv;
1713 LAPACK(cgesv)(¶ms->N, ¶ms->NRHS,
1714 params->A, ¶ms->LDA,
1715 params->IPIV,
1716 params->B, ¶ms->LDB,
1717 &rv);
1718 return rv;
1719}
1720
1721static inline fortran_int
1722call_gesv(GESV_PARAMS_t<fortran_doublecomplex> *params)
1723{
1724 fortran_int rv;
1725 LAPACK(zgesv)(¶ms->N, ¶ms->NRHS,
1726 params->A, ¶ms->LDA,
1727 params->IPIV,
1728 params->B, ¶ms->LDB,
1729 &rv);
1730 return rv;
1731}
1732
1733
/*
 * Initialize the parameters to use for the LAPACK function _gesv.
 * Handles buffer allocation.
 */
1738template<typename ftyp>
1739static inline int
1740init_gesv(GESV_PARAMS_t<ftyp> *params, fortran_int N, fortran_int NRHS)
1741{
1742 npy_uint8 *mem_buff = NULL;
1743 npy_uint8 *a, *b, *ipiv;
1744 size_t safe_N = N;
1745 size_t safe_NRHS = NRHS;
1746 fortran_int ld = fortran_int_max(N, 1);
1747 mem_buff = (npy_uint8 *)malloc(safe_N * safe_N * sizeof(ftyp) +
1748 safe_N * safe_NRHS*sizeof(ftyp) +
1749 safe_N * sizeof(fortran_int));
1750 if (!mem_buff) {
1751 goto error;
1752 }
1753 a = mem_buff;
1754 b = a + safe_N * safe_N * sizeof(ftyp);
1755 ipiv = b + safe_N * safe_NRHS * sizeof(ftyp);
1756
1757 params->A = (ftyp*)a;
1758 params->B = (ftyp*)b;
1759 params->IPIV = (fortran_int*)ipiv;
1760 params->N = N;
1761 params->NRHS = NRHS;
1762 params->LDA = ld;
1763 params->LDB = ld;
1764
1765 return 1;
1766 error:
1767 free(mem_buff);
1768 memset(params, 0, sizeof(*params));
1769
1770 return 0;
1771}
1772
1773template<typename ftyp>
1774static inline void
1775release_gesv(GESV_PARAMS_t<ftyp> *params)
1776{
1777 /* memory block base is in A */
1778 free(params->A);
1779 memset(params, 0, sizeof(*params));
1780}
1781
1782template<typename typ>
1783static void
1784solve(char **args, npy_intp const *dimensions, npy_intp const *steps,
1785 void *NPY_UNUSED(func))
1786{
1787using ftyp = fortran_type_t<typ>;
1788 GESV_PARAMS_t<ftyp> params;
1789 fortran_int n, nrhs;
1790 int error_occurred = get_fp_invalid_and_clear();
1791 INIT_OUTER_LOOP_3
1792
1793 n = (fortran_int)dimensions[0];
1794 nrhs = (fortran_int)dimensions[1];
1795 if (init_gesv(¶ms, n, nrhs)) {
1796 LINEARIZE_DATA_t a_in, b_in, r_out;
1797
1798 init_linearize_data(&a_in, n, n, steps[1], steps[0]);
1799 init_linearize_data(&b_in, nrhs, n, steps[3], steps[2]);
1800 init_linearize_data(&r_out, nrhs, n, steps[5], steps[4]);
1801
1802 BEGIN_OUTER_LOOP_3
1803 int not_ok;
1804 linearize_matrix((typ*)params.A, (typ*)args[0], &a_in);
1805 linearize_matrix((typ*)params.B, (typ*)args[1], &b_in);
1806 not_ok =call_gesv(¶ms);
1807 if (!not_ok) {
1808 delinearize_matrix((typ*)args[2], (typ*)params.B, &r_out);
1809 } else {
1810 error_occurred = 1;
1811 nan_matrix((typ*)args[2], &r_out);
1812 }
1813 END_OUTER_LOOP
1814
1815 release_gesv(¶ms);
1816 }
1817
1818 set_fp_invalid_or_clear(error_occurred);
1819}
1820
1821
1822template<typename typ>
1823static void
1824solve1(char **args, npy_intp const *dimensions, npy_intp const *steps,
1825 void *NPY_UNUSED(func))
1826{
1827using ftyp = fortran_type_t<typ>;
1828 GESV_PARAMS_t<ftyp> params;
1829 int error_occurred = get_fp_invalid_and_clear();
1830 fortran_int n;
1831 INIT_OUTER_LOOP_3
1832
1833 n = (fortran_int)dimensions[0];
1834 if (init_gesv(¶ms, n, 1)) {
1835 LINEARIZE_DATA_t a_in, b_in, r_out;
1836 init_linearize_data(&a_in, n, n, steps[1], steps[0]);
1837 init_linearize_data(&b_in, 1, n, 1, steps[2]);
1838 init_linearize_data(&r_out, 1, n, 1, steps[3]);
1839
1840 BEGIN_OUTER_LOOP_3
1841 int not_ok;
1842 linearize_matrix((typ*)params.A, (typ*)args[0], &a_in);
1843 linearize_matrix((typ*)params.B, (typ*)args[1], &b_in);
1844 not_ok = call_gesv(¶ms);
1845 if (!not_ok) {
1846 delinearize_matrix((typ*)args[2], (typ*)params.B, &r_out);
1847 } else {
1848 error_occurred = 1;
1849 nan_matrix((typ*)args[2], &r_out);
1850 }
1851 END_OUTER_LOOP
1852
1853 release_gesv(¶ms);
1854 }
1855
1856 set_fp_invalid_or_clear(error_occurred);
1857}
1858
1859template<typename typ>
1860static void
1861inv(char **args, npy_intp const *dimensions, npy_intp const *steps,
1862 void *NPY_UNUSED(func))
1863{
1864using ftyp = fortran_type_t<typ>;
1865 GESV_PARAMS_t<ftyp> params;
1866 fortran_int n;
1867 int error_occurred = get_fp_invalid_and_clear();
1868 INIT_OUTER_LOOP_2
1869
1870 n = (fortran_int)dimensions[0];
1871 if (init_gesv(¶ms, n, n)) {
1872 LINEARIZE_DATA_t a_in, r_out;
1873 init_linearize_data(&a_in, n, n, steps[1], steps[0]);
1874 init_linearize_data(&r_out, n, n, steps[3], steps[2]);
1875
1876 BEGIN_OUTER_LOOP_2
1877 int not_ok;
1878 linearize_matrix((typ*)params.A, (typ*)args[0], &a_in);
1879 identity_matrix((typ*)params.B, n);
1880 not_ok = call_gesv(¶ms);
1881 if (!not_ok) {
1882 delinearize_matrix((typ*)args[1], (typ*)params.B, &r_out);
1883 } else {
1884 error_occurred = 1;
1885 nan_matrix((typ*)args[1], &r_out);
1886 }
1887 END_OUTER_LOOP
1888
1889 release_gesv(¶ms);
1890 }
1891
1892 set_fp_invalid_or_clear(error_occurred);
1893}
1894
1895
1896/* -------------------------------------------------------------------------- */
1897 /* Cholesky decomposition */
1898
/*
 * Argument bundle handed to the call_potrf() overloads.  init_potrf()
 * allocates A and release_potrf() frees it.
 */
template<typename typ>
struct POTR_PARAMS_t
{
    typ *A;          /* N x N matrix buffer */
    fortran_int N;   /* order of A */
    fortran_int LDA; /* leading dimension of A, max(N, 1) */
    char UPLO;       /* forwarded unchanged to LAPACK */
};
1907
1908
1909static inline fortran_int
1910call_potrf(POTR_PARAMS_t<fortran_real> *params)
1911{
1912 fortran_int rv;
1913 LAPACK(spotrf)(¶ms->UPLO,
1914 ¶ms->N, params->A, ¶ms->LDA,
1915 &rv);
1916 return rv;
1917}
1918
1919static inline fortran_int
1920call_potrf(POTR_PARAMS_t<fortran_doublereal> *params)
1921{
1922 fortran_int rv;
1923 LAPACK(dpotrf)(¶ms->UPLO,
1924 ¶ms->N, params->A, ¶ms->LDA,
1925 &rv);
1926 return rv;
1927}
1928
1929static inline fortran_int
1930call_potrf(POTR_PARAMS_t<fortran_complex> *params)
1931{
1932 fortran_int rv;
1933 LAPACK(cpotrf)(¶ms->UPLO,
1934 ¶ms->N, params->A, ¶ms->LDA,
1935 &rv);
1936 return rv;
1937}
1938
1939static inline fortran_int
1940call_potrf(POTR_PARAMS_t<fortran_doublecomplex> *params)
1941{
1942 fortran_int rv;
1943 LAPACK(zpotrf)(¶ms->UPLO,
1944 ¶ms->N, params->A, ¶ms->LDA,
1945 &rv);
1946 return rv;
1947}
1948
1949template<typename ftyp>
1950static inline int
1951init_potrf(POTR_PARAMS_t<ftyp> *params, char UPLO, fortran_int N)
1952{
1953 npy_uint8 *mem_buff = NULL;
1954 npy_uint8 *a;
1955 size_t safe_N = N;
1956 fortran_int lda = fortran_int_max(N, 1);
1957
1958 mem_buff = (npy_uint8 *)malloc(safe_N * safe_N * sizeof(ftyp));
1959 if (!mem_buff) {
1960 goto error;
1961 }
1962
1963 a = mem_buff;
1964
1965 params->A = (ftyp*)a;
1966 params->N = N;
1967 params->LDA = lda;
1968 params->UPLO = UPLO;
1969
1970 return 1;
1971 error:
1972 free(mem_buff);
1973 memset(params, 0, sizeof(*params));
1974
1975 return 0;
1976}
1977
1978template<typename ftyp>
1979static inline void
1980release_potrf(POTR_PARAMS_t<ftyp> *params)
1981{
1982 /* memory block base in A */
1983 free(params->A);
1984 memset(params, 0, sizeof(*params));
1985}
1986
1987template<typename typ>
1988static void
1989cholesky(char uplo, char **args, npy_intp const *dimensions, npy_intp const *steps)
1990{
1991 using ftyp = fortran_type_t<typ>;
1992 POTR_PARAMS_t<ftyp> params;
1993 int error_occurred = get_fp_invalid_and_clear();
1994 fortran_int n;
1995 INIT_OUTER_LOOP_2
1996
1997 assert(uplo == 'L');
1998
1999 n = (fortran_int)dimensions[0];
2000 if (init_potrf(¶ms, uplo, n)) {
2001 LINEARIZE_DATA_t a_in, r_out;
2002 init_linearize_data(&a_in, n, n, steps[1], steps[0]);
2003 init_linearize_data(&r_out, n, n, steps[3], steps[2]);
2004 BEGIN_OUTER_LOOP_2
2005 int not_ok;
2006 linearize_matrix(params.A, (ftyp*)args[0], &a_in);
2007 not_ok = call_potrf(¶ms);
2008 if (!not_ok) {
2009 triu_matrix(params.A, params.N);
2010 delinearize_matrix((ftyp*)args[1], params.A, &r_out);
2011 } else {
2012 error_occurred = 1;
2013 nan_matrix((ftyp*)args[1], &r_out);
2014 }
2015 END_OUTER_LOOP
2016 release_potrf(¶ms);
2017 }
2018
2019 set_fp_invalid_or_clear(error_occurred);
2020}
2021
/*
 * gufunc loop entry point: runs cholesky() with uplo='L'.  The `func`
 * argument is unused.
 */
template<typename typ>
static void
cholesky_lo(char **args, npy_intp const *dimensions, npy_intp const *steps,
                void *NPY_UNUSED(func))
{
    cholesky<typ>('L', args, dimensions, steps);
}
2029
2030/* -------------------------------------------------------------------------- */
2031 /* eig family */
2032
/*
 * Argument bundle handed to the call_geev() overloads, plus the buffers that
 * receive the complex-valued final results (W, VL, VR).
 */
template<typename typ>
struct GEEV_PARAMS_t {
    typ *A;   /* input matrix buffer */
    basetype_t<typ> *WR; /* RWORK in complex versions, REAL W buffer for (sd)geev*/
    typ *WI;  /* imaginary parts of the eigenvalues for (sd)geev */
    typ *VLR; /* REAL VL buffers for _geev where _ is s, d */
    typ *VRR; /* REAL VR buffers for _geev where _ is s, d */
    typ *WORK; /* work buffer, LWORK elements */
    typ *W;  /* final w */
    typ *VL; /* final vl */
    typ *VR; /* final vr */

    fortran_int N;     /* order of A */
    fortran_int LDA;   /* leading dimension of A */
    fortran_int LDVL;  /* leading dimension of the VL buffer */
    fortran_int LDVR;  /* leading dimension of the VR buffer */
    fortran_int LWORK; /* element count of WORK (-1 during the size query) */

    char JOBVL; /* 'V' when the left eigenvectors are wanted */
    char JOBVR; /* 'V' when the right eigenvectors are wanted */
};
2054
2055template<typename typ>
2056static inline void
2057dump_geev_params(const char *name, GEEV_PARAMS_t<typ>* params)
2058{
2059 TRACE_TXT("\n%s\n"
2060
2061 "\t%10s: %p\n"\
2062 "\t%10s: %p\n"\
2063 "\t%10s: %p\n"\
2064 "\t%10s: %p\n"\
2065 "\t%10s: %p\n"\
2066 "\t%10s: %p\n"\
2067 "\t%10s: %p\n"\
2068 "\t%10s: %p\n"\
2069 "\t%10s: %p\n"\
2070
2071 "\t%10s: %d\n"\
2072 "\t%10s: %d\n"\
2073 "\t%10s: %d\n"\
2074 "\t%10s: %d\n"\
2075 "\t%10s: %d\n"\
2076
2077 "\t%10s: %c\n"\
2078 "\t%10s: %c\n",
2079
2080 name,
2081
2082 "A", params->A,
2083 "WR", params->WR,
2084 "WI", params->WI,
2085 "VLR", params->VLR,
2086 "VRR", params->VRR,
2087 "WORK", params->WORK,
2088 "W", params->W,
2089 "VL", params->VL,
2090 "VR", params->VR,
2091
2092 "N", (int)params->N,
2093 "LDA", (int)params->LDA,
2094 "LDVL", (int)params->LDVL,
2095 "LDVR", (int)params->LDVR,
2096 "LWORK", (int)params->LWORK,
2097
2098 "JOBVL", params->JOBVL,
2099 "JOBVR", params->JOBVR);
2100}
2101
2102static inline fortran_int
2103call_geev(GEEV_PARAMS_t<float>* params)
2104{
2105 fortran_int rv;
2106 LAPACK(sgeev)(¶ms->JOBVL, ¶ms->JOBVR,
2107 ¶ms->N, params->A, ¶ms->LDA,
2108 params->WR, params->WI,
2109 params->VLR, ¶ms->LDVL,
2110 params->VRR, ¶ms->LDVR,
2111 params->WORK, ¶ms->LWORK,
2112 &rv);
2113 return rv;
2114}
2115
2116static inline fortran_int
2117call_geev(GEEV_PARAMS_t<double>* params)
2118{
2119 fortran_int rv;
2120 LAPACK(dgeev)(¶ms->JOBVL, ¶ms->JOBVR,
2121 ¶ms->N, params->A, ¶ms->LDA,
2122 params->WR, params->WI,
2123 params->VLR, ¶ms->LDVL,
2124 params->VRR, ¶ms->LDVR,
2125 params->WORK, ¶ms->LWORK,
2126 &rv);
2127 return rv;
2128}
2129
2130
2131template<typename typ>
2132static inline int
2133init_geev(GEEV_PARAMS_t<typ> *params, char jobvl, char jobvr, fortran_int n,
2134scalar_trait)
2135{
2136 npy_uint8 *mem_buff = NULL;
2137 npy_uint8 *mem_buff2 = NULL;
2138 npy_uint8 *a, *wr, *wi, *vlr, *vrr, *work, *w, *vl, *vr;
2139 size_t safe_n = n;
2140 size_t a_size = safe_n * safe_n * sizeof(typ);
2141 size_t wr_size = safe_n * sizeof(typ);
2142 size_t wi_size = safe_n * sizeof(typ);
2143 size_t vlr_size = jobvl=='V' ? safe_n * safe_n * sizeof(typ) : 0;
2144 size_t vrr_size = jobvr=='V' ? safe_n * safe_n * sizeof(typ) : 0;
2145 size_t w_size = wr_size*2;
2146 size_t vl_size = vlr_size*2;
2147 size_t vr_size = vrr_size*2;
2148 size_t work_count = 0;
2149 fortran_int ld = fortran_int_max(n, 1);
2150
2151 /* allocate data for known sizes (all but work) */
2152 mem_buff = (npy_uint8 *)malloc(a_size + wr_size + wi_size +
2153 vlr_size + vrr_size +
2154 w_size + vl_size + vr_size);
2155 if (!mem_buff) {
2156 goto error;
2157 }
2158
2159 a = mem_buff;
2160 wr = a + a_size;
2161 wi = wr + wr_size;
2162 vlr = wi + wi_size;
2163 vrr = vlr + vlr_size;
2164 w = vrr + vrr_size;
2165 vl = w + w_size;
2166 vr = vl + vl_size;
2167
2168 params->A = (typ*)a;
2169 params->WR = (typ*)wr;
2170 params->WI = (typ*)wi;
2171 params->VLR = (typ*)vlr;
2172 params->VRR = (typ*)vrr;
2173 params->W = (typ*)w;
2174 params->VL = (typ*)vl;
2175 params->VR = (typ*)vr;
2176 params->N = n;
2177 params->LDA = ld;
2178 params->LDVL = ld;
2179 params->LDVR = ld;
2180 params->JOBVL = jobvl;
2181 params->JOBVR = jobvr;
2182
2183 /* Work size query */
2184 {
2185 typ work_size_query;
2186
2187 params->LWORK = -1;
2188 params->WORK = &work_size_query;
2189
2190 if (call_geev(params) != 0) {
2191 goto error;
2192 }
2193
2194 work_count = (size_t)work_size_query;
2195 }
2196
2197 mem_buff2 = (npy_uint8 *)malloc(work_count*sizeof(typ));
2198 if (!mem_buff2) {
2199 goto error;
2200 }
2201 work = mem_buff2;
2202
2203 params->LWORK = (fortran_int)work_count;
2204 params->WORK = (typ*)work;
2205
2206 return 1;
2207 error:
2208 free(mem_buff2);
2209 free(mem_buff);
2210 memset(params, 0, sizeof(*params));
2211
2212 return 0;
2213}
2214
2215template<typename complextyp, typename typ>
2216static inline void
2217mk_complex_array_from_real(complextyp *c, const typ *re, size_t n)
2218{
2219 size_t iter;
2220 for (iter = 0; iter < n; ++iter) {
2221 c[iter].r = re[iter];
2222 c[iter].i = numeric_limits<typ>::zero;
2223 }
2224}
2225
/*
 * Fill c[0..n) with complex values assembled from two parallel arrays:
 * re[k] becomes the real part and im[k] the imaginary part of c[k].
 */
template<typename complextyp, typename typ>
static inline void
mk_complex_array(complextyp *c,
                 const typ *re,
                 const typ *im,
                 size_t n)
{
    for (size_t k = 0; k != n; ++k) {
        complextyp *slot = c + k;
        slot->r = re[k];
        slot->i = im[k];
    }
}
2239
/*
 * Expand two consecutive real columns of length n (r[0..n) holding real
 * parts, r[n..2n) holding imaginary parts) into two complex columns:
 * c[0..n) receives (re, im) and c[n..2n) receives the conjugate (re, -im).
 */
template<typename complextyp, typename typ>
static inline void
mk_complex_array_conjugate_pair(complextyp *c,
                                const typ *r,
                                size_t n)
{
    const typ *real_part = r;
    const typ *imag_part = r + n;
    complextyp *vec = c;
    complextyp *vec_conj = c + n;

    for (size_t left = n; left != 0; --left) {
        const typ re = *real_part++;
        const typ im = *imag_part++;

        vec->r = re;
        vec->i = im;
        ++vec;

        vec_conj->r = re;
        vec_conj->i = -im;
        ++vec_conj;
    }
}
2256
2257/*
2258 * make the complex eigenvectors from the real array produced by sgeev/zgeev.
2259 * c is the array where the results will be left.
2260 * r is the source array of reals produced by sgeev/zgeev
2261 * i is the eigenvalue imaginary part produced by sgeev/zgeev
2262 * n is so that the order of the matrix is n by n
2263 */
2264template<typename complextyp, typename typ>
2265static inline void
2266mk_geev_complex_eigenvectors(complextyp *c,
2267 const typ *r,
2268 const typ *i,
2269 size_t n)
2270{
2271 size_t iter = 0;
2272 while (iter < n)
2273 {
2274 if (i[iter] == numeric_limits<typ>::zero) {
2275 /* eigenvalue was real, eigenvectors as well... */
2276 mk_complex_array_from_real(c, r, n);
2277 c += n;
2278 r += n;
2279 iter ++;
2280 } else {
2281 /* eigenvalue was complex, generate a pair of eigenvectors */
2282 mk_complex_array_conjugate_pair(c, r, n);
2283 c += 2*n;
2284 r += 2*n;
2285 iter += 2;
2286 }
2287 }
2288}
2289
2290
2291template<typename complextyp, typename typ>
2292static inline void
2293process_geev_results(GEEV_PARAMS_t<typ> *params, scalar_trait)
2294{
2295 /* REAL versions of geev need the results to be translated
2296 * into complex versions. This is the way to deal with imaginary
2297 * results. In our gufuncs we will always return complex arrays!
2298 */
2299 mk_complex_array((complextyp*)params->W, (typ*)params->WR, (typ*)params->WI, params->N);
2300
2301 /* handle the eigenvectors */
2302 if ('V' == params->JOBVL) {
2303 mk_geev_complex_eigenvectors((complextyp*)params->VL, (typ*)params->VLR,
2304 (typ*)params->WI, params->N);
2305 }
2306 if ('V' == params->JOBVR) {
2307 mk_geev_complex_eigenvectors((complextyp*)params->VR, (typ*)params->VRR,
2308 (typ*)params->WI, params->N);
2309 }
2310}
2311
2312
2313static inline fortran_int
2314call_geev(GEEV_PARAMS_t<fortran_complex>* params)
2315{
2316 fortran_int rv;
2317
2318 LAPACK(cgeev)(¶ms->JOBVL, ¶ms->JOBVR,
2319 ¶ms->N, params->A, ¶ms->LDA,
2320 params->W,
2321 params->VL, ¶ms->LDVL,
2322 params->VR, ¶ms->LDVR,
2323 params->WORK, ¶ms->LWORK,
2324 params->WR, /* actually RWORK */
2325 &rv);
2326 return rv;
2327}
2328static inline fortran_int
2329call_geev(GEEV_PARAMS_t<fortran_doublecomplex>* params)
2330{
2331 fortran_int rv;
2332
2333 LAPACK(zgeev)(¶ms->JOBVL, ¶ms->JOBVR,
2334 ¶ms->N, params->A, ¶ms->LDA,
2335 params->W,
2336 params->VL, ¶ms->LDVL,
2337 params->VR, ¶ms->LDVR,
2338 params->WORK, ¶ms->LWORK,
2339 params->WR, /* actually RWORK */
2340 &rv);
2341 return rv;
2342}
2343
2344template<typename ftyp>
2345static inline int
2346init_geev(GEEV_PARAMS_t<ftyp>* params,
2347 char jobvl,
2348 char jobvr,
2349 fortran_int n, complex_trait)
2350{
2351using realtyp = basetype_t<ftyp>;
2352 npy_uint8 *mem_buff = NULL;
2353 npy_uint8 *mem_buff2 = NULL;
2354 npy_uint8 *a, *w, *vl, *vr, *work, *rwork;
2355 size_t safe_n = n;
2356 size_t a_size = safe_n * safe_n * sizeof(ftyp);
2357 size_t w_size = safe_n * sizeof(ftyp);
2358 size_t vl_size = jobvl=='V'? safe_n * safe_n * sizeof(ftyp) : 0;
2359 size_t vr_size = jobvr=='V'? safe_n * safe_n * sizeof(ftyp) : 0;
2360 size_t rwork_size = 2 * safe_n * sizeof(realtyp);
2361 size_t work_count = 0;
2362 size_t total_size = a_size + w_size + vl_size + vr_size + rwork_size;
2363 fortran_int ld = fortran_int_max(n, 1);
2364
2365 mem_buff = (npy_uint8 *)malloc(total_size);
2366 if (!mem_buff) {
2367 goto error;
2368 }
2369
2370 a = mem_buff;
2371 w = a + a_size;
2372 vl = w + w_size;
2373 vr = vl + vl_size;
2374 rwork = vr + vr_size;
2375
2376 params->A = (ftyp*)a;
2377 params->WR = (realtyp*)rwork;
2378 params->WI = NULL;
2379 params->VLR = NULL;
2380 params->VRR = NULL;
2381 params->VL = (ftyp*)vl;
2382 params->VR = (ftyp*)vr;
2383 params->W = (ftyp*)w;
2384 params->N = n;
2385 params->LDA = ld;
2386 params->LDVL = ld;
2387 params->LDVR = ld;
2388 params->JOBVL = jobvl;
2389 params->JOBVR = jobvr;
2390
2391 /* Work size query */
2392 {
2393 ftyp work_size_query;
2394
2395 params->LWORK = -1;
2396 params->WORK = &work_size_query;
2397
2398 if (call_geev(params) != 0) {
2399 goto error;
2400 }
2401
2402 work_count = (size_t) work_size_query.r;
2403 /* Fix a bug in lapack 3.0.0 */
2404 if(work_count == 0) work_count = 1;
2405 }
2406
2407 mem_buff2 = (npy_uint8 *)malloc(work_count*sizeof(ftyp));
2408 if (!mem_buff2) {
2409 goto error;
2410 }
2411
2412 work = mem_buff2;
2413
2414 params->LWORK = (fortran_int)work_count;
2415 params->WORK = (ftyp*)work;
2416
2417 return 1;
2418 error:
2419 free(mem_buff2);
2420 free(mem_buff);
2421 memset(params, 0, sizeof(*params));
2422
2423 return 0;
2424}
2425
2426template<typename complextyp, typename typ>
2427static inline void
2428process_geev_results(GEEV_PARAMS_t<typ> *NPY_UNUSED(params), complex_trait)
2429{
2430 /* nothing to do here, complex versions are ready to copy out */
2431}
2432
2433
2434
2435template<typename typ>
2436static inline void
2437release_geev(GEEV_PARAMS_t<typ> *params)
2438{
2439 free(params->WORK);
2440 free(params->A);
2441 memset(params, 0, sizeof(*params));
2442}
2443
2444template<typename fctype, typename ftype>
2445static inline void
2446eig_wrapper(char JOBVL,
2447 char JOBVR,
2448 char**args,
2449 npy_intp const *dimensions,
2450 npy_intp const *steps)
2451{
2452 ptrdiff_t outer_steps[4];
2453 size_t iter;
2454 size_t outer_dim = *dimensions++;
2455 size_t op_count = 2;
2456 int error_occurred = get_fp_invalid_and_clear();
2457 GEEV_PARAMS_t<ftype> geev_params;
2458
2459 assert(JOBVL == 'N');
2460
2461 STACK_TRACE;
2462 op_count += 'V'==JOBVL?1:0;
2463 op_count += 'V'==JOBVR?1:0;
2464
2465 for (iter = 0; iter < op_count; ++iter) {
2466 outer_steps[iter] = (ptrdiff_t) steps[iter];
2467 }
2468 steps += op_count;
2469
2470 if (init_geev(&geev_params,
2471 JOBVL, JOBVR,
2472 (fortran_int)dimensions[0], dispatch_scalar<ftype>())) {
2473 LINEARIZE_DATA_t a_in;
2474 LINEARIZE_DATA_t w_out;
2475 LINEARIZE_DATA_t vl_out;
2476 LINEARIZE_DATA_t vr_out;
2477
2478 init_linearize_data(&a_in,
2479 geev_params.N, geev_params.N,
2480 steps[1], steps[0]);
2481 steps += 2;
2482 init_linearize_data(&w_out,
2483 1, geev_params.N,
2484 0, steps[0]);
2485 steps += 1;
2486 if ('V' == geev_params.JOBVL) {
2487 init_linearize_data(&vl_out,
2488 geev_params.N, geev_params.N,
2489 steps[1], steps[0]);
2490 steps += 2;
2491 }
2492 if ('V' == geev_params.JOBVR) {
2493 init_linearize_data(&vr_out,
2494 geev_params.N, geev_params.N,
2495 steps[1], steps[0]);
2496 }
2497
2498 for (iter = 0; iter < outer_dim; ++iter) {
2499 int not_ok;
2500 char **arg_iter = args;
2501 /* copy the matrix in */
2502 linearize_matrix((ftype*)geev_params.A, (ftype*)*arg_iter++, &a_in);
2503 not_ok = call_geev(&geev_params);
2504
2505 if (!not_ok) {
2506 process_geev_results<fctype>(&geev_params,
2507dispatch_scalar<ftype>{});
2508 delinearize_matrix((fctype*)*arg_iter++,
2509 (fctype*)geev_params.W,
2510 &w_out);
2511
2512 if ('V' == geev_params.JOBVL) {
2513 delinearize_matrix((fctype*)*arg_iter++,
2514 (fctype*)geev_params.VL,
2515 &vl_out);
2516 }
2517 if ('V' == geev_params.JOBVR) {
2518 delinearize_matrix((fctype*)*arg_iter++,
2519 (fctype*)geev_params.VR,
2520 &vr_out);
2521 }
2522 } else {
2523 /* geev failed */
2524 error_occurred = 1;
2525 nan_matrix((fctype*)*arg_iter++, &w_out);
2526 if ('V' == geev_params.JOBVL) {
2527 nan_matrix((fctype*)*arg_iter++, &vl_out);
2528 }
2529 if ('V' == geev_params.JOBVR) {
2530 nan_matrix((fctype*)*arg_iter++, &vr_out);
2531 }
2532 }
2533 update_pointers((npy_uint8**)args, outer_steps, op_count);
2534 }
2535
2536 release_geev(&geev_params);
2537 }
2538
2539 set_fp_invalid_or_clear(error_occurred);
2540}
2541
2542template<typename fctype, typename ftype>
2543static void
2544eig(char **args,
2545 npy_intp const *dimensions,
2546 npy_intp const *steps,
2547 void *NPY_UNUSED(func))
2548{
2549 eig_wrapper<fctype, ftype>('N', 'V', args, dimensions, steps);
2550}
2551
2552template<typename fctype, typename ftype>
2553static void
2554eigvals(char **args,
2555 npy_intp const *dimensions,
2556 npy_intp const *steps,
2557 void *NPY_UNUSED(func))
2558{
2559 eig_wrapper<fctype, ftype>('N', 'N', args, dimensions, steps);
2560}
2561
2562
2563
2564/* -------------------------------------------------------------------------- */
2565 /* singular value decomposition */
2566
2567template<typename ftyp>
2568struct GESDD_PARAMS_t
2569{
2570 ftyp *A;
2571 basetype_t<ftyp> *S;
2572 ftyp *U;
2573 ftyp *VT;
2574 ftyp *WORK;
2575 basetype_t<ftyp> *RWORK;
2576 fortran_int *IWORK;
2577
2578 fortran_int M;
2579 fortran_int N;
2580 fortran_int LDA;
2581 fortran_int LDU;
2582 fortran_int LDVT;
2583 fortran_int LWORK;
2584 char JOBZ;
2585} ;
2586
2587
2588template<typename ftyp>
2589static inline void
2590dump_gesdd_params(const char *name,
2591 GESDD_PARAMS_t<ftyp> *params)
2592{
2593 TRACE_TXT("\n%s:\n"\
2594
2595 "%14s: %18p\n"\
2596 "%14s: %18p\n"\
2597 "%14s: %18p\n"\
2598 "%14s: %18p\n"\
2599 "%14s: %18p\n"\
2600 "%14s: %18p\n"\
2601 "%14s: %18p\n"\
2602
2603 "%14s: %18d\n"\
2604 "%14s: %18d\n"\
2605 "%14s: %18d\n"\
2606 "%14s: %18d\n"\
2607 "%14s: %18d\n"\
2608 "%14s: %18d\n"\
2609
2610 "%14s: %15c'%c'\n",
2611
2612 name,
2613
2614 "A", params->A,
2615 "S", params->S,
2616 "U", params->U,
2617 "VT", params->VT,
2618 "WORK", params->WORK,
2619 "RWORK", params->RWORK,
2620 "IWORK", params->IWORK,
2621
2622 "M", (int)params->M,
2623 "N", (int)params->N,
2624 "LDA", (int)params->LDA,
2625 "LDU", (int)params->LDU,
2626 "LDVT", (int)params->LDVT,
2627 "LWORK", (int)params->LWORK,
2628
2629 "JOBZ", ' ', params->JOBZ);
2630}
2631
2632static inline int
2633compute_urows_vtcolumns(char jobz,
2634 fortran_int m, fortran_int n,
2635 fortran_int *urows, fortran_int *vtcolumns)
2636{
2637 fortran_int min_m_n = fortran_int_min(m, n);
2638 switch(jobz)
2639 {
2640 case 'N':
2641 *urows = 0;
2642 *vtcolumns = 0;
2643 break;
2644 case 'A':
2645 *urows = m;
2646 *vtcolumns = n;
2647 break;
2648 case 'S':
2649 {
2650 *urows = min_m_n;
2651 *vtcolumns = min_m_n;
2652 }
2653 break;
2654 default:
2655 return 0;
2656 }
2657
2658 return 1;
2659}
2660
2661static inline fortran_int
2662call_gesdd(GESDD_PARAMS_t<fortran_real> *params)
2663{
2664 fortran_int rv;
2665 LAPACK(sgesdd)(¶ms->JOBZ, ¶ms->M, ¶ms->N,
2666 params->A, ¶ms->LDA,
2667 params->S,
2668 params->U, ¶ms->LDU,
2669 params->VT, ¶ms->LDVT,
2670 params->WORK, ¶ms->LWORK,
2671 (fortran_int*)params->IWORK,
2672 &rv);
2673 return rv;
2674}
2675static inline fortran_int
2676call_gesdd(GESDD_PARAMS_t<fortran_doublereal> *params)
2677{
2678 fortran_int rv;
2679 LAPACK(dgesdd)(¶ms->JOBZ, ¶ms->M, ¶ms->N,
2680 params->A, ¶ms->LDA,
2681 params->S,
2682 params->U, ¶ms->LDU,
2683 params->VT, ¶ms->LDVT,
2684 params->WORK, ¶ms->LWORK,
2685 (fortran_int*)params->IWORK,
2686 &rv);
2687 return rv;
2688}
2689
2690template<typename ftyp>
2691static inline int
2692init_gesdd(GESDD_PARAMS_t<ftyp> *params,
2693 char jobz,
2694 fortran_int m,
2695 fortran_int n, scalar_trait)
2696{
2697 npy_uint8 *mem_buff = NULL;
2698 npy_uint8 *mem_buff2 = NULL;
2699 npy_uint8 *a, *s, *u, *vt, *work, *iwork;
2700 size_t safe_m = m;
2701 size_t safe_n = n;
2702 size_t a_size = safe_m * safe_n * sizeof(ftyp);
2703 fortran_int min_m_n = fortran_int_min(m, n);
2704 size_t safe_min_m_n = min_m_n;
2705 size_t s_size = safe_min_m_n * sizeof(ftyp);
2706 fortran_int u_row_count, vt_column_count;
2707 size_t safe_u_row_count, safe_vt_column_count;
2708 size_t u_size, vt_size;
2709 fortran_int work_count;
2710 size_t work_size;
2711 size_t iwork_size = 8 * safe_min_m_n * sizeof(fortran_int);
2712 fortran_int ld = fortran_int_max(m, 1);
2713
2714 if (!compute_urows_vtcolumns(jobz, m, n, &u_row_count, &vt_column_count)) {
2715 goto error;
2716 }
2717
2718 safe_u_row_count = u_row_count;
2719 safe_vt_column_count = vt_column_count;
2720
2721 u_size = safe_u_row_count * safe_m * sizeof(ftyp);
2722 vt_size = safe_n * safe_vt_column_count * sizeof(ftyp);
2723
2724 mem_buff = (npy_uint8 *)malloc(a_size + s_size + u_size + vt_size + iwork_size);
2725
2726 if (!mem_buff) {
2727 goto error;
2728 }
2729
2730 a = mem_buff;
2731 s = a + a_size;
2732 u = s + s_size;
2733 vt = u + u_size;
2734 iwork = vt + vt_size;
2735
2736 /* fix vt_column_count so that it is a valid lapack parameter (0 is not) */
2737 vt_column_count = fortran_int_max(1, vt_column_count);
2738
2739 params->M = m;
2740 params->N = n;
2741 params->A = (ftyp*)a;
2742 params->S = (ftyp*)s;
2743 params->U = (ftyp*)u;
2744 params->VT = (ftyp*)vt;
2745 params->RWORK = NULL;
2746 params->IWORK = (fortran_int*)iwork;
2747 params->LDA = ld;
2748 params->LDU = ld;
2749 params->LDVT = vt_column_count;
2750 params->JOBZ = jobz;
2751
2752 /* Work size query */
2753 {
2754 ftyp work_size_query;
2755
2756 params->LWORK = -1;
2757 params->WORK = &work_size_query;
2758
2759 if (call_gesdd(params) != 0) {
2760 goto error;
2761 }
2762
2763 work_count = (fortran_int)work_size_query;
2764 /* Fix a bug in lapack 3.0.0 */
2765 if(work_count == 0) work_count = 1;
2766 work_size = (size_t)work_count * sizeof(ftyp);
2767 }
2768
2769 mem_buff2 = (npy_uint8 *)malloc(work_size);
2770 if (!mem_buff2) {
2771 goto error;
2772 }
2773
2774 work = mem_buff2;
2775
2776 params->LWORK = work_count;
2777 params->WORK = (ftyp*)work;
2778
2779 return 1;
2780 error:
2781 TRACE_TXT("%s failed init\n", __FUNCTION__);
2782 free(mem_buff);
2783 free(mem_buff2);
2784 memset(params, 0, sizeof(*params));
2785
2786 return 0;
2787}
2788
2789static inline fortran_int
2790call_gesdd(GESDD_PARAMS_t<fortran_complex> *params)
2791{
2792 fortran_int rv;
2793 LAPACK(cgesdd)(¶ms->JOBZ, ¶ms->M, ¶ms->N,
2794 params->A, ¶ms->LDA,
2795 params->S,
2796 params->U, ¶ms->LDU,
2797 params->VT, ¶ms->LDVT,
2798 params->WORK, ¶ms->LWORK,
2799 params->RWORK,
2800 params->IWORK,
2801 &rv);
2802 return rv;
2803}
2804static inline fortran_int
2805call_gesdd(GESDD_PARAMS_t<fortran_doublecomplex> *params)
2806{
2807 fortran_int rv;
2808 LAPACK(zgesdd)(¶ms->JOBZ, ¶ms->M, ¶ms->N,
2809 params->A, ¶ms->LDA,
2810 params->S,
2811 params->U, ¶ms->LDU,
2812 params->VT, ¶ms->LDVT,
2813 params->WORK, ¶ms->LWORK,
2814 params->RWORK,
2815 params->IWORK,
2816 &rv);
2817 return rv;
2818}
2819
2820template<typename ftyp>
2821static inline int
2822init_gesdd(GESDD_PARAMS_t<ftyp> *params,
2823 char jobz,
2824 fortran_int m,
2825 fortran_int n, complex_trait)
2826{
2827using frealtyp = basetype_t<ftyp>;
2828 npy_uint8 *mem_buff = NULL, *mem_buff2 = NULL;
2829 npy_uint8 *a,*s, *u, *vt, *work, *rwork, *iwork;
2830 size_t a_size, s_size, u_size, vt_size, work_size, rwork_size, iwork_size;
2831 size_t safe_u_row_count, safe_vt_column_count;
2832 fortran_int u_row_count, vt_column_count, work_count;
2833 size_t safe_m = m;
2834 size_t safe_n = n;
2835 fortran_int min_m_n = fortran_int_min(m, n);
2836 size_t safe_min_m_n = min_m_n;
2837 fortran_int ld = fortran_int_max(m, 1);
2838
2839 if (!compute_urows_vtcolumns(jobz, m, n, &u_row_count, &vt_column_count)) {
2840 goto error;
2841 }
2842
2843 safe_u_row_count = u_row_count;
2844 safe_vt_column_count = vt_column_count;
2845
2846 a_size = safe_m * safe_n * sizeof(ftyp);
2847 s_size = safe_min_m_n * sizeof(frealtyp);
2848 u_size = safe_u_row_count * safe_m * sizeof(ftyp);
2849 vt_size = safe_n * safe_vt_column_count * sizeof(ftyp);
2850 rwork_size = 'N'==jobz?
2851 (7 * safe_min_m_n) :
2852 (5*safe_min_m_n * safe_min_m_n + 5*safe_min_m_n);
2853 rwork_size *= sizeof(ftyp);
2854 iwork_size = 8 * safe_min_m_n* sizeof(fortran_int);
2855
2856 mem_buff = (npy_uint8 *)malloc(a_size +
2857 s_size +
2858 u_size +
2859 vt_size +
2860 rwork_size +
2861 iwork_size);
2862 if (!mem_buff) {
2863 goto error;
2864 }
2865
2866 a = mem_buff;
2867 s = a + a_size;
2868 u = s + s_size;
2869 vt = u + u_size;
2870 rwork = vt + vt_size;
2871 iwork = rwork + rwork_size;
2872
2873 /* fix vt_column_count so that it is a valid lapack parameter (0 is not) */
2874 vt_column_count = fortran_int_max(1, vt_column_count);
2875
2876 params->A = (ftyp*)a;
2877 params->S = (frealtyp*)s;
2878 params->U = (ftyp*)u;
2879 params->VT = (ftyp*)vt;
2880 params->RWORK = (frealtyp*)rwork;
2881 params->IWORK = (fortran_int*)iwork;
2882 params->M = m;
2883 params->N = n;
2884 params->LDA = ld;
2885 params->LDU = ld;
2886 params->LDVT = vt_column_count;
2887 params->JOBZ = jobz;
2888
2889 /* Work size query */
2890 {
2891 ftyp work_size_query;
2892
2893 params->LWORK = -1;
2894 params->WORK = &work_size_query;
2895
2896 if (call_gesdd(params) != 0) {
2897 goto error;
2898 }
2899
2900 work_count = (fortran_int)(*(frealtyp*)&work_size_query);
2901 /* Fix a bug in lapack 3.0.0 */
2902 if(work_count == 0) work_count = 1;
2903 work_size = (size_t)work_count * sizeof(ftyp);
2904 }
2905
2906 mem_buff2 = (npy_uint8 *)malloc(work_size);
2907 if (!mem_buff2) {
2908 goto error;
2909 }
2910
2911 work = mem_buff2;
2912
2913 params->LWORK = work_count;
2914 params->WORK = (ftyp*)work;
2915
2916 return 1;
2917 error:
2918 TRACE_TXT("%s failed init\n", __FUNCTION__);
2919 free(mem_buff2);
2920 free(mem_buff);
2921 memset(params, 0, sizeof(*params));
2922
2923 return 0;
2924}
2925
2926template<typename typ>
2927static inline void
2928release_gesdd(GESDD_PARAMS_t<typ>* params)
2929{
2930 /* A and WORK contain allocated blocks */
2931 free(params->A);
2932 free(params->WORK);
2933 memset(params, 0, sizeof(*params));
2934}
2935
2936template<typename typ>
2937static inline void
2938svd_wrapper(char JOBZ,
2939 char **args,
2940 npy_intp const *dimensions,
2941 npy_intp const *steps)
2942{
2943using basetyp = basetype_t<typ>;
2944 ptrdiff_t outer_steps[4];
2945 int error_occurred = get_fp_invalid_and_clear();
2946 size_t iter;
2947 size_t outer_dim = *dimensions++;
2948 size_t op_count = (JOBZ=='N')?2:4;
2949 GESDD_PARAMS_t<typ> params;
2950
2951 for (iter = 0; iter < op_count; ++iter) {
2952 outer_steps[iter] = (ptrdiff_t) steps[iter];
2953 }
2954 steps += op_count;
2955
2956 if (init_gesdd(¶ms,
2957 JOBZ,
2958 (fortran_int)dimensions[0],
2959 (fortran_int)dimensions[1],
2960dispatch_scalar<typ>())) {
2961 LINEARIZE_DATA_t a_in, u_out, s_out, v_out;
2962 fortran_int min_m_n = params.M < params.N ? params.M : params.N;
2963
2964 init_linearize_data(&a_in, params.N, params.M, steps[1], steps[0]);
2965 if ('N' == params.JOBZ) {
2966 /* only the singular values are wanted */
2967 init_linearize_data(&s_out, 1, min_m_n, 0, steps[2]);
2968 } else {
2969 fortran_int u_columns, v_rows;
2970 if ('S' == params.JOBZ) {
2971 u_columns = min_m_n;
2972 v_rows = min_m_n;
2973 } else { /* JOBZ == 'A' */
2974 u_columns = params.M;
2975 v_rows = params.N;
2976 }
2977 init_linearize_data(&u_out,
2978 u_columns, params.M,
2979 steps[3], steps[2]);
2980 init_linearize_data(&s_out,
2981 1, min_m_n,
2982 0, steps[4]);
2983 init_linearize_data(&v_out,
2984 params.N, v_rows,
2985 steps[6], steps[5]);
2986 }
2987
2988 for (iter = 0; iter < outer_dim; ++iter) {
2989 int not_ok;
2990 /* copy the matrix in */
2991 linearize_matrix((typ*)params.A, (typ*)args[0], &a_in);
2992 not_ok = call_gesdd(¶ms);
2993 if (!not_ok) {
2994 if ('N' == params.JOBZ) {
2995 delinearize_matrix((basetyp*)args[1], (basetyp*)params.S, &s_out);
2996 } else {
2997 if ('A' == params.JOBZ && min_m_n == 0) {
2998 /* Lapack has betrayed us and left these uninitialized,
2999 * so produce an identity matrix for whichever of u
3000 * and v is not empty.
3001 */
3002 identity_matrix((typ*)params.U, params.M);
3003 identity_matrix((typ*)params.VT, params.N);
3004 }
3005
3006 delinearize_matrix((typ*)args[1], (typ*)params.U, &u_out);
3007 delinearize_matrix((basetyp*)args[2], (basetyp*)params.S, &s_out);
3008 delinearize_matrix((typ*)args[3], (typ*)params.VT, &v_out);
3009 }
3010 } else {
3011 error_occurred = 1;
3012 if ('N' == params.JOBZ) {
3013 nan_matrix((basetyp*)args[1], &s_out);
3014 } else {
3015 nan_matrix((typ*)args[1], &u_out);
3016 nan_matrix((basetyp*)args[2], &s_out);
3017 nan_matrix((typ*)args[3], &v_out);
3018 }
3019 }
3020 update_pointers((npy_uint8**)args, outer_steps, op_count);
3021 }
3022
3023 release_gesdd(¶ms);
3024 }
3025
3026 set_fp_invalid_or_clear(error_occurred);
3027}
3028
3029
3030template<typename typ>
3031static void
3032svd_N(char **args,
3033 npy_intp const *dimensions,
3034 npy_intp const *steps,
3035 void *NPY_UNUSED(func))
3036{
3037 svd_wrapper<fortran_type_t<typ>>('N', args, dimensions, steps);
3038}
3039
3040template<typename typ>
3041static void
3042svd_S(char **args,
3043 npy_intp const *dimensions,
3044 npy_intp const *steps,
3045 void *NPY_UNUSED(func))
3046{
3047 svd_wrapper<fortran_type_t<typ>>('S', args, dimensions, steps);
3048}
3049
3050template<typename typ>
3051static void
3052svd_A(char **args,
3053 npy_intp const *dimensions,
3054 npy_intp const *steps,
3055 void *NPY_UNUSED(func))
3056{
3057 svd_wrapper<fortran_type_t<typ>>('A', args, dimensions, steps);
3058}
3059
3060/* -------------------------------------------------------------------------- */
3061 /* qr (modes - r, raw) */
3062
3063template<typename typ>
3064struct GEQRF_PARAMS_t
3065{
3066 fortran_int M;
3067 fortran_int N;
3068 typ *A;
3069 fortran_int LDA;
3070 typ* TAU;
3071 typ *WORK;
3072 fortran_int LWORK;
3073};
3074
3075
3076template<typename typ>
3077static inline void
3078dump_geqrf_params(const char *name,
3079 GEQRF_PARAMS_t<typ> *params)
3080{
3081 TRACE_TXT("\n%s:\n"\
3082
3083 "%14s: %18p\n"\
3084 "%14s: %18p\n"\
3085 "%14s: %18p\n"\
3086 "%14s: %18d\n"\
3087 "%14s: %18d\n"\
3088 "%14s: %18d\n"\
3089 "%14s: %18d\n",
3090
3091 name,
3092
3093 "A", params->A,
3094 "TAU", params->TAU,
3095 "WORK", params->WORK,
3096
3097 "M", (int)params->M,
3098 "N", (int)params->N,
3099 "LDA", (int)params->LDA,
3100 "LWORK", (int)params->LWORK);
3101}
3102
3103static inline fortran_int
3104call_geqrf(GEQRF_PARAMS_t<double> *params)
3105{
3106 fortran_int rv;
3107 LAPACK(dgeqrf)(¶ms->M, ¶ms->N,
3108 params->A, ¶ms->LDA,
3109 params->TAU,
3110 params->WORK, ¶ms->LWORK,
3111 &rv);
3112 return rv;
3113}
3114static inline fortran_int
3115call_geqrf(GEQRF_PARAMS_t<f2c_doublecomplex> *params)
3116{
3117 fortran_int rv;
3118 LAPACK(zgeqrf)(¶ms->M, ¶ms->N,
3119 params->A, ¶ms->LDA,
3120 params->TAU,
3121 params->WORK, ¶ms->LWORK,
3122 &rv);
3123 return rv;
3124}
3125
3126
3127static inline int
3128init_geqrf(GEQRF_PARAMS_t<fortran_doublereal> *params,
3129 fortran_int m,
3130 fortran_int n)
3131{
3132using ftyp = fortran_doublereal;
3133 npy_uint8 *mem_buff = NULL;
3134 npy_uint8 *mem_buff2 = NULL;
3135 npy_uint8 *a, *tau, *work;
3136 fortran_int min_m_n = fortran_int_min(m, n);
3137 size_t safe_min_m_n = min_m_n;
3138 size_t safe_m = m;
3139 size_t safe_n = n;
3140
3141 size_t a_size = safe_m * safe_n * sizeof(ftyp);
3142 size_t tau_size = safe_min_m_n * sizeof(ftyp);
3143
3144 fortran_int work_count;
3145 size_t work_size;
3146 fortran_int lda = fortran_int_max(1, m);
3147
3148 mem_buff = (npy_uint8 *)malloc(a_size + tau_size);
3149
3150 if (!mem_buff)
3151 goto error;
3152
3153 a = mem_buff;
3154 tau = a + a_size;
3155 memset(tau, 0, tau_size);
3156
3157
3158 params->M = m;
3159 params->N = n;
3160 params->A = (ftyp*)a;
3161 params->TAU = (ftyp*)tau;
3162 params->LDA = lda;
3163
3164 {
3165 /* compute optimal work size */
3166
3167 ftyp work_size_query;
3168
3169 params->WORK = &work_size_query;
3170 params->LWORK = -1;
3171
3172 if (call_geqrf(params) != 0)
3173 goto error;
3174
3175 work_count = (fortran_int) *(ftyp*) params->WORK;
3176
3177 }
3178
3179 params->LWORK = fortran_int_max(fortran_int_max(1, n), work_count);
3180
3181 work_size = (size_t) params->LWORK * sizeof(ftyp);
3182 mem_buff2 = (npy_uint8 *)malloc(work_size);
3183 if (!mem_buff2)
3184 goto error;
3185
3186 work = mem_buff2;
3187
3188 params->WORK = (ftyp*)work;
3189
3190 return 1;
3191 error:
3192 TRACE_TXT("%s failed init\n", __FUNCTION__);
3193 free(mem_buff);
3194 free(mem_buff2);
3195 memset(params, 0, sizeof(*params));
3196
3197 return 0;
3198}
3199
3200
3201static inline int
3202init_geqrf(GEQRF_PARAMS_t<fortran_doublecomplex> *params,
3203 fortran_int m,
3204 fortran_int n)
3205{
3206using ftyp = fortran_doublecomplex;
3207 npy_uint8 *mem_buff = NULL;
3208 npy_uint8 *mem_buff2 = NULL;
3209 npy_uint8 *a, *tau, *work;
3210 fortran_int min_m_n = fortran_int_min(m, n);
3211 size_t safe_min_m_n = min_m_n;
3212 size_t safe_m = m;
3213 size_t safe_n = n;
3214
3215 size_t a_size = safe_m * safe_n * sizeof(ftyp);
3216 size_t tau_size = safe_min_m_n * sizeof(ftyp);
3217
3218 fortran_int work_count;
3219 size_t work_size;
3220 fortran_int lda = fortran_int_max(1, m);
3221
3222 mem_buff = (npy_uint8 *)malloc(a_size + tau_size);
3223
3224 if (!mem_buff)
3225 goto error;
3226
3227 a = mem_buff;
3228 tau = a + a_size;
3229 memset(tau, 0, tau_size);
3230
3231
3232 params->M = m;
3233 params->N = n;
3234 params->A = (ftyp*)a;
3235 params->TAU = (ftyp*)tau;
3236 params->LDA = lda;
3237
3238 {
3239 /* compute optimal work size */
3240
3241 ftyp work_size_query;
3242
3243 params->WORK = &work_size_query;
3244 params->LWORK = -1;
3245
3246 if (call_geqrf(params) != 0)
3247 goto error;
3248
3249 work_count = (fortran_int) ((ftyp*)params->WORK)->r;
3250
3251 }
3252
3253 params->LWORK = fortran_int_max(fortran_int_max(1, n),
3254 work_count);
3255
3256 work_size = (size_t) params->LWORK * sizeof(ftyp);
3257
3258 mem_buff2 = (npy_uint8 *)malloc(work_size);
3259 if (!mem_buff2)
3260 goto error;
3261
3262 work = mem_buff2;
3263
3264 params->WORK = (ftyp*)work;
3265
3266 return 1;
3267 error:
3268 TRACE_TXT("%s failed init\n", __FUNCTION__);
3269 free(mem_buff);
3270 free(mem_buff2);
3271 memset(params, 0, sizeof(*params));
3272
3273 return 0;
3274}
3275
3276
3277template<typename ftyp>
3278static inline void
3279release_geqrf(GEQRF_PARAMS_t<ftyp>* params)
3280{
3281 /* A and WORK contain allocated blocks */
3282 free(params->A);
3283 free(params->WORK);
3284 memset(params, 0, sizeof(*params));
3285}
3286
3287template<typename typ>
3288static void
3289qr_r_raw(char **args, npy_intp const *dimensions, npy_intp const *steps,
3290 void *NPY_UNUSED(func))
3291{
3292using ftyp = fortran_type_t<typ>;
3293
3294 GEQRF_PARAMS_t<ftyp> params;
3295 int error_occurred = get_fp_invalid_and_clear();
3296 fortran_int n, m;
3297
3298 INIT_OUTER_LOOP_2
3299
3300 m = (fortran_int)dimensions[0];
3301 n = (fortran_int)dimensions[1];
3302
3303 if (init_geqrf(¶ms, m, n)) {
3304 LINEARIZE_DATA_t a_in, tau_out;
3305
3306 init_linearize_data(&a_in, n, m, steps[1], steps[0]);
3307 init_linearize_data(&tau_out, 1, fortran_int_min(m, n), 1, steps[2]);
3308
3309 BEGIN_OUTER_LOOP_2
3310 int not_ok;
3311 linearize_matrix((typ*)params.A, (typ*)args[0], &a_in);
3312 not_ok = call_geqrf(¶ms);
3313 if (!not_ok) {
3314 delinearize_matrix((typ*)args[0], (typ*)params.A, &a_in);
3315 delinearize_matrix((typ*)args[1], (typ*)params.TAU, &tau_out);
3316 } else {
3317 error_occurred = 1;
3318 nan_matrix((typ*)args[1], &tau_out);
3319 }
3320 END_OUTER_LOOP
3321
3322 release_geqrf(¶ms);
3323 }
3324
3325 set_fp_invalid_or_clear(error_occurred);
3326}
3327
3328
3329/* -------------------------------------------------------------------------- */
3330 /* qr common code (modes - reduced and complete) */
3331
3332template<typename typ>
3333struct GQR_PARAMS_t
3334{
3335 fortran_int M;
3336 fortran_int MC;
3337 fortran_int MN;
3338 void* A;
3339 typ *Q;
3340 fortran_int LDA;
3341 typ* TAU;
3342 typ *WORK;
3343 fortran_int LWORK;
3344} ;
3345
3346static inline fortran_int
3347call_gqr(GQR_PARAMS_t<double> *params)
3348{
3349 fortran_int rv;
3350 LAPACK(dorgqr)(¶ms->M, ¶ms->MC, ¶ms->MN,
3351 params->Q, ¶ms->LDA,
3352 params->TAU,
3353 params->WORK, ¶ms->LWORK,
3354 &rv);
3355 return rv;
3356}
3357static inline fortran_int
3358call_gqr(GQR_PARAMS_t<f2c_doublecomplex> *params)
3359{
3360 fortran_int rv;
3361 LAPACK(zungqr)(¶ms->M, ¶ms->MC, ¶ms->MN,
3362 params->Q, ¶ms->LDA,
3363 params->TAU,
3364 params->WORK, ¶ms->LWORK,
3365 &rv);
3366 return rv;
3367}
3368
3369static inline int
3370init_gqr_common(GQR_PARAMS_t<fortran_doublereal> *params,
3371 fortran_int m,
3372 fortran_int n,
3373 fortran_int mc)
3374{
3375using ftyp = fortran_doublereal;
3376 npy_uint8 *mem_buff = NULL;
3377 npy_uint8 *mem_buff2 = NULL;
3378 npy_uint8 *a, *q, *tau, *work;
3379 fortran_int min_m_n = fortran_int_min(m, n);
3380 size_t safe_mc = mc;
3381 size_t safe_min_m_n = min_m_n;
3382 size_t safe_m = m;
3383 size_t safe_n = n;
3384 size_t a_size = safe_m * safe_n * sizeof(ftyp);
3385 size_t q_size = safe_m * safe_mc * sizeof(ftyp);
3386 size_t tau_size = safe_min_m_n * sizeof(ftyp);
3387
3388 fortran_int work_count;
3389 size_t work_size;
3390 fortran_int lda = fortran_int_max(1, m);
3391
3392 mem_buff = (npy_uint8 *)malloc(q_size + tau_size + a_size);
3393
3394 if (!mem_buff)
3395 goto error;
3396
3397 q = mem_buff;
3398 tau = q + q_size;
3399 a = tau + tau_size;
3400
3401
3402 params->M = m;
3403 params->MC = mc;
3404 params->MN = min_m_n;
3405 params->A = a;
3406 params->Q = (ftyp*)q;
3407 params->TAU = (ftyp*)tau;
3408 params->LDA = lda;
3409
3410 {
3411 /* compute optimal work size */
3412 ftyp work_size_query;
3413
3414 params->WORK = &work_size_query;
3415 params->LWORK = -1;
3416
3417 if (call_gqr(params) != 0)
3418 goto error;
3419
3420 work_count = (fortran_int) *(ftyp*) params->WORK;
3421
3422 }
3423
3424 params->LWORK = fortran_int_max(fortran_int_max(1, n), work_count);
3425
3426 work_size = (size_t) params->LWORK * sizeof(ftyp);
3427
3428 mem_buff2 = (npy_uint8 *)malloc(work_size);
3429 if (!mem_buff2)
3430 goto error;
3431
3432 work = mem_buff2;
3433
3434 params->WORK = (ftyp*)work;
3435
3436 return 1;
3437 error:
3438 TRACE_TXT("%s failed init\n", __FUNCTION__);
3439 free(mem_buff);
3440 free(mem_buff2);
3441 memset(params, 0, sizeof(*params));
3442
3443 return 0;
3444}
3445
3446
3447static inline int
3448init_gqr_common(GQR_PARAMS_t<fortran_doublecomplex> *params,
3449 fortran_int m,
3450 fortran_int n,
3451 fortran_int mc)
3452{
3453using ftyp=fortran_doublecomplex;
3454 npy_uint8 *mem_buff = NULL;
3455 npy_uint8 *mem_buff2 = NULL;
3456 npy_uint8 *a, *q, *tau, *work;
3457 fortran_int min_m_n = fortran_int_min(m, n);
3458 size_t safe_mc = mc;
3459 size_t safe_min_m_n = min_m_n;
3460 size_t safe_m = m;
3461 size_t safe_n = n;
3462
3463 size_t a_size = safe_m * safe_n * sizeof(ftyp);
3464 size_t q_size = safe_m * safe_mc * sizeof(ftyp);
3465 size_t tau_size = safe_min_m_n * sizeof(ftyp);
3466
3467 fortran_int work_count;
3468 size_t work_size;
3469 fortran_int lda = fortran_int_max(1, m);
3470
3471 mem_buff = (npy_uint8 *)malloc(q_size + tau_size + a_size);
3472
3473 if (!mem_buff)
3474 goto error;
3475
3476 q = mem_buff;
3477 tau = q + q_size;
3478 a = tau + tau_size;
3479
3480
3481 params->M = m;
3482 params->MC = mc;
3483 params->MN = min_m_n;
3484 params->A = a;
3485 params->Q = (ftyp*)q;
3486 params->TAU = (ftyp*)tau;
3487 params->LDA = lda;
3488
3489 {
3490 /* compute optimal work size */
3491 ftyp work_size_query;
3492
3493 params->WORK = &work_size_query;
3494 params->LWORK = -1;
3495
3496 if (call_gqr(params) != 0)
3497 goto error;
3498
3499 work_count = (fortran_int) ((ftyp*)params->WORK)->r;
3500
3501 }
3502
3503 params->LWORK = fortran_int_max(fortran_int_max(1, n),
3504 work_count);
3505
3506 work_size = (size_t) params->LWORK * sizeof(ftyp);
3507
3508 mem_buff2 = (npy_uint8 *)malloc(work_size);
3509 if (!mem_buff2)
3510 goto error;
3511
3512 work = mem_buff2;
3513
3514 params->WORK = (ftyp*)work;
3515 params->LWORK = work_count;
3516
3517 return 1;
3518 error:
3519 TRACE_TXT("%s failed init\n", __FUNCTION__);
3520 free(mem_buff);
3521 free(mem_buff2);
3522 memset(params, 0, sizeof(*params));
3523
3524 return 0;
3525}
3526
3527/* -------------------------------------------------------------------------- */
3528 /* qr (modes - reduced) */
3529
3530
3531template<typename typ>
3532static inline void
3533dump_gqr_params(const char *name,
3534 GQR_PARAMS_t<typ> *params)
3535{
3536 TRACE_TXT("\n%s:\n"\
3537
3538 "%14s: %18p\n"\
3539 "%14s: %18p\n"\
3540 "%14s: %18p\n"\
3541 "%14s: %18d\n"\
3542 "%14s: %18d\n"\
3543 "%14s: %18d\n"\
3544 "%14s: %18d\n"\
3545 "%14s: %18d\n",
3546
3547 name,
3548
3549 "Q", params->Q,
3550 "TAU", params->TAU,
3551 "WORK", params->WORK,
3552
3553 "M", (int)params->M,
3554 "MC", (int)params->MC,
3555 "MN", (int)params->MN,
3556 "LDA", (int)params->LDA,
3557 "LWORK", (int)params->LWORK);
3558}
3559
3560template<typename ftyp>
3561static inline int
3562init_gqr(GQR_PARAMS_t<ftyp> *params,
3563 fortran_int m,
3564 fortran_int n)
3565{
3566 return init_gqr_common(
3567 params, m, n,
3568 fortran_int_min(m, n));
3569}
3570
3571
3572template<typename typ>
3573static inline void
3574release_gqr(GQR_PARAMS_t<typ>* params)
3575{
3576 /* A and WORK contain allocated blocks */
3577 free(params->Q);
3578 free(params->WORK);
3579 memset(params, 0, sizeof(*params));
3580}
3581
3582template<typename typ>
3583static void
3584qr_reduced(char **args, npy_intp const *dimensions, npy_intp const *steps,
3585 void *NPY_UNUSED(func))
3586{
3587using ftyp = fortran_type_t<typ>;
3588 GQR_PARAMS_t<ftyp> params;
3589 int error_occurred = get_fp_invalid_and_clear();
3590 fortran_int n, m;
3591
3592 INIT_OUTER_LOOP_3
3593
3594 m = (fortran_int)dimensions[0];
3595 n = (fortran_int)dimensions[1];
3596
3597 if (init_gqr(¶ms, m, n)) {
3598 LINEARIZE_DATA_t a_in, tau_in, q_out;
3599
3600 init_linearize_data(&a_in, n, m, steps[1], steps[0]);
3601 init_linearize_data(&tau_in, 1, fortran_int_min(m, n), 1, steps[2]);
3602 init_linearize_data(&q_out, fortran_int_min(m, n), m, steps[4], steps[3]);
3603
3604 BEGIN_OUTER_LOOP_3
3605 int not_ok;
3606 linearize_matrix((typ*)params.A, (typ*)args[0], &a_in);
3607 linearize_matrix((typ*)params.Q, (typ*)args[0], &a_in);
3608 linearize_matrix((typ*)params.TAU, (typ*)args[1], &tau_in);
3609 not_ok = call_gqr(¶ms);
3610 if (!not_ok) {
3611 delinearize_matrix((typ*)args[2], (typ*)params.Q, &q_out);
3612 } else {
3613 error_occurred = 1;
3614 nan_matrix((typ*)args[2], &q_out);
3615 }
3616 END_OUTER_LOOP
3617
3618 release_gqr(¶ms);
3619 }
3620
3621 set_fp_invalid_or_clear(error_occurred);
3622}
3623
3624/* -------------------------------------------------------------------------- */
3625 /* qr (modes - complete) */
3626
3627template<typename ftyp>
3628static inline int
3629init_gqr_complete(GQR_PARAMS_t<ftyp> *params,
3630 fortran_int m,
3631 fortran_int n)
3632{
3633 return init_gqr_common(params, m, n, m);
3634}
3635
3636
3637template<typename typ>
3638static void
3639qr_complete(char **args, npy_intp const *dimensions, npy_intp const *steps,
3640 void *NPY_UNUSED(func))
3641{
3642using ftyp = fortran_type_t<typ>;
3643 GQR_PARAMS_t<ftyp> params;
3644 int error_occurred = get_fp_invalid_and_clear();
3645 fortran_int n, m;
3646
3647 INIT_OUTER_LOOP_3
3648
3649 m = (fortran_int)dimensions[0];
3650 n = (fortran_int)dimensions[1];
3651
3652
3653 if (init_gqr_complete(¶ms, m, n)) {
3654 LINEARIZE_DATA_t a_in, tau_in, q_out;
3655
3656 init_linearize_data(&a_in, n, m, steps[1], steps[0]);
3657 init_linearize_data(&tau_in, 1, fortran_int_min(m, n), 1, steps[2]);
3658 init_linearize_data(&q_out, m, m, steps[4], steps[3]);
3659
3660 BEGIN_OUTER_LOOP_3
3661 int not_ok;
3662 linearize_matrix((typ*)params.A, (typ*)args[0], &a_in);
3663 linearize_matrix((typ*)params.Q, (typ*)args[0], &a_in);
3664 linearize_matrix((typ*)params.TAU, (typ*)args[1], &tau_in);
3665 not_ok = call_gqr(¶ms);
3666 if (!not_ok) {
3667 delinearize_matrix((typ*)args[2], (typ*)params.Q, &q_out);
3668 } else {
3669 error_occurred = 1;
3670 nan_matrix((typ*)args[2], &q_out);
3671 }
3672 END_OUTER_LOOP
3673
3674 release_gqr(¶ms);
3675 }
3676
3677 set_fp_invalid_or_clear(error_occurred);
3678}
3679
3680/* -------------------------------------------------------------------------- */
3681 /* least squares */
3682
/*
 * Argument bundle for the LAPACK xGELSD calls made by call_gelsd().
 * Filled in by init_gelsd(), released by release_gelsd().
 */
template<typename typ>
struct GELSD_PARAMS_t
{
    fortran_int M;              /* set from m by init_gelsd */
    fortran_int N;              /* set from n by init_gelsd */
    fortran_int NRHS;           /* set from nrhs by init_gelsd */
    typ *A;                     /* m*n elements; base of the first allocation */
    fortran_int LDA;            /* max(1, m) */
    typ *B;                     /* max(m, n)*nrhs elements; lstsq reads the solution from here */
    fortran_int LDB;            /* max(1, max(m, n)) */
    basetype_t<typ> *S;         /* min(m, n) values, copied to the `s` output by lstsq */
    basetype_t<typ> *RCOND;     /* points at the caller's rcond operand (set per item in lstsq) */
    fortran_int RANK;           /* passed by address to LAPACK; copied to the rank output */
    typ *WORK;                  /* base of the second allocation */
    fortran_int LWORK;          /* -1 during the size query, then the queried element count */
    basetype_t<typ> *RWORK;     /* NULL in the real-typed init; allocated in the complex one */
    fortran_int *IWORK;         /* integer workspace, sized by the query */
};
3701
3702template<typename typ>
3703static inline void
3704dump_gelsd_params(const char *name,
3705 GELSD_PARAMS_t<typ> *params)
3706{
3707 TRACE_TXT("\n%s:\n"\
3708
3709 "%14s: %18p\n"\
3710 "%14s: %18p\n"\
3711 "%14s: %18p\n"\
3712 "%14s: %18p\n"\
3713 "%14s: %18p\n"\
3714 "%14s: %18p\n"\
3715
3716 "%14s: %18d\n"\
3717 "%14s: %18d\n"\
3718 "%14s: %18d\n"\
3719 "%14s: %18d\n"\
3720 "%14s: %18d\n"\
3721 "%14s: %18d\n"\
3722 "%14s: %18d\n"\
3723
3724 "%14s: %18p\n",
3725
3726 name,
3727
3728 "A", params->A,
3729 "B", params->B,
3730 "S", params->S,
3731 "WORK", params->WORK,
3732 "RWORK", params->RWORK,
3733 "IWORK", params->IWORK,
3734
3735 "M", (int)params->M,
3736 "N", (int)params->N,
3737 "NRHS", (int)params->NRHS,
3738 "LDA", (int)params->LDA,
3739 "LDB", (int)params->LDB,
3740 "LWORK", (int)params->LWORK,
3741 "RANK", (int)params->RANK,
3742
3743 "RCOND", params->RCOND);
3744}
3745
3746static inline fortran_int
3747call_gelsd(GELSD_PARAMS_t<fortran_real> *params)
3748{
3749 fortran_int rv;
3750 LAPACK(sgelsd)(¶ms->M, ¶ms->N, ¶ms->NRHS,
3751 params->A, ¶ms->LDA,
3752 params->B, ¶ms->LDB,
3753 params->S,
3754 params->RCOND, ¶ms->RANK,
3755 params->WORK, ¶ms->LWORK,
3756 params->IWORK,
3757 &rv);
3758 return rv;
3759}
3760
3761
3762static inline fortran_int
3763call_gelsd(GELSD_PARAMS_t<fortran_doublereal> *params)
3764{
3765 fortran_int rv;
3766 LAPACK(dgelsd)(¶ms->M, ¶ms->N, ¶ms->NRHS,
3767 params->A, ¶ms->LDA,
3768 params->B, ¶ms->LDB,
3769 params->S,
3770 params->RCOND, ¶ms->RANK,
3771 params->WORK, ¶ms->LWORK,
3772 params->IWORK,
3773 &rv);
3774 return rv;
3775}
3776
3777
3778template<typename ftyp>
3779static inline int
3780init_gelsd(GELSD_PARAMS_t<ftyp> *params,
3781 fortran_int m,
3782 fortran_int n,
3783 fortran_int nrhs,
3784scalar_trait)
3785{
3786 npy_uint8 *mem_buff = NULL;
3787 npy_uint8 *mem_buff2 = NULL;
3788 npy_uint8 *a, *b, *s, *work, *iwork;
3789 fortran_int min_m_n = fortran_int_min(m, n);
3790 fortran_int max_m_n = fortran_int_max(m, n);
3791 size_t safe_min_m_n = min_m_n;
3792 size_t safe_max_m_n = max_m_n;
3793 size_t safe_m = m;
3794 size_t safe_n = n;
3795 size_t safe_nrhs = nrhs;
3796
3797 size_t a_size = safe_m * safe_n * sizeof(ftyp);
3798 size_t b_size = safe_max_m_n * safe_nrhs * sizeof(ftyp);
3799 size_t s_size = safe_min_m_n * sizeof(ftyp);
3800
3801 fortran_int work_count;
3802 size_t work_size;
3803 size_t iwork_size;
3804 fortran_int lda = fortran_int_max(1, m);
3805 fortran_int ldb = fortran_int_max(1, fortran_int_max(m,n));
3806
3807 size_t msize = a_size + b_size + s_size;
3808 mem_buff = (npy_uint8 *)malloc(msize != 0 ? msize : 1);
3809
3810 if (!mem_buff) {
3811 goto no_memory;
3812 }
3813 a = mem_buff;
3814 b = a + a_size;
3815 s = b + b_size;
3816
3817 params->M = m;
3818 params->N = n;
3819 params->NRHS = nrhs;
3820 params->A = (ftyp*)a;
3821 params->B = (ftyp*)b;
3822 params->S = (ftyp*)s;
3823 params->LDA = lda;
3824 params->LDB = ldb;
3825
3826 {
3827 /* compute optimal work size */
3828 ftyp work_size_query;
3829 fortran_int iwork_size_query;
3830
3831 params->WORK = &work_size_query;
3832 params->IWORK = &iwork_size_query;
3833 params->RWORK = NULL;
3834 params->LWORK = -1;
3835
3836 if (call_gelsd(params) != 0) {
3837 goto error;
3838 }
3839 work_count = (fortran_int)work_size_query;
3840
3841 work_size = (size_t) work_size_query * sizeof(ftyp);
3842 iwork_size = (size_t)iwork_size_query * sizeof(fortran_int);
3843 }
3844
3845 mem_buff2 = (npy_uint8 *)malloc(work_size + iwork_size);
3846 if (!mem_buff2) {
3847 goto no_memory;
3848 }
3849 work = mem_buff2;
3850 iwork = work + work_size;
3851
3852 params->WORK = (ftyp*)work;
3853 params->RWORK = NULL;
3854 params->IWORK = (fortran_int*)iwork;
3855 params->LWORK = work_count;
3856
3857 return 1;
3858
3859 no_memory:
3860 NPY_ALLOW_C_API_DEF
3861 NPY_ALLOW_C_API;
3862 PyErr_NoMemory();
3863 NPY_DISABLE_C_API;
3864
3865 error:
3866 TRACE_TXT("%s failed init\n", __FUNCTION__);
3867 free(mem_buff);
3868 free(mem_buff2);
3869 memset(params, 0, sizeof(*params));
3870 return 0;
3871}
3872
3873static inline fortran_int
3874call_gelsd(GELSD_PARAMS_t<fortran_complex> *params)
3875{
3876 fortran_int rv;
3877 LAPACK(cgelsd)(¶ms->M, ¶ms->N, ¶ms->NRHS,
3878 params->A, ¶ms->LDA,
3879 params->B, ¶ms->LDB,
3880 params->S,
3881 params->RCOND, ¶ms->RANK,
3882 params->WORK, ¶ms->LWORK,
3883 params->RWORK, (fortran_int*)params->IWORK,
3884 &rv);
3885 return rv;
3886}
3887
3888static inline fortran_int
3889call_gelsd(GELSD_PARAMS_t<fortran_doublecomplex> *params)
3890{
3891 fortran_int rv;
3892 LAPACK(zgelsd)(¶ms->M, ¶ms->N, ¶ms->NRHS,
3893 params->A, ¶ms->LDA,
3894 params->B, ¶ms->LDB,
3895 params->S,
3896 params->RCOND, ¶ms->RANK,
3897 params->WORK, ¶ms->LWORK,
3898 params->RWORK, (fortran_int*)params->IWORK,
3899 &rv);
3900 return rv;
3901}
3902
3903
3904template<typename ftyp>
3905static inline int
3906init_gelsd(GELSD_PARAMS_t<ftyp> *params,
3907 fortran_int m,
3908 fortran_int n,
3909 fortran_int nrhs,
3910complex_trait)
3911{
3912using frealtyp = basetype_t<ftyp>;
3913 npy_uint8 *mem_buff = NULL;
3914 npy_uint8 *mem_buff2 = NULL;
3915 npy_uint8 *a, *b, *s, *work, *iwork, *rwork;
3916 fortran_int min_m_n = fortran_int_min(m, n);
3917 fortran_int max_m_n = fortran_int_max(m, n);
3918 size_t safe_min_m_n = min_m_n;
3919 size_t safe_max_m_n = max_m_n;
3920 size_t safe_m = m;
3921 size_t safe_n = n;
3922 size_t safe_nrhs = nrhs;
3923
3924 size_t a_size = safe_m * safe_n * sizeof(ftyp);
3925 size_t b_size = safe_max_m_n * safe_nrhs * sizeof(ftyp);
3926 size_t s_size = safe_min_m_n * sizeof(frealtyp);
3927
3928 fortran_int work_count;
3929 size_t work_size, rwork_size, iwork_size;
3930 fortran_int lda = fortran_int_max(1, m);
3931 fortran_int ldb = fortran_int_max(1, fortran_int_max(m,n));
3932
3933 size_t msize = a_size + b_size + s_size;
3934 mem_buff = (npy_uint8 *)malloc(msize != 0 ? msize : 1);
3935
3936 if (!mem_buff) {
3937 goto no_memory;
3938 }
3939
3940 a = mem_buff;
3941 b = a + a_size;
3942 s = b + b_size;
3943
3944 params->M = m;
3945 params->N = n;
3946 params->NRHS = nrhs;
3947 params->A = (ftyp*)a;
3948 params->B = (ftyp*)b;
3949 params->S = (frealtyp*)s;
3950 params->LDA = lda;
3951 params->LDB = ldb;
3952
3953 {
3954 /* compute optimal work size */
3955 ftyp work_size_query;
3956 frealtyp rwork_size_query;
3957 fortran_int iwork_size_query;
3958
3959 params->WORK = &work_size_query;
3960 params->IWORK = &iwork_size_query;
3961 params->RWORK = &rwork_size_query;
3962 params->LWORK = -1;
3963
3964 if (call_gelsd(params) != 0) {
3965 goto error;
3966 }
3967
3968 work_count = (fortran_int)work_size_query.r;
3969
3970 work_size = (size_t )work_size_query.r * sizeof(ftyp);
3971 rwork_size = (size_t)rwork_size_query * sizeof(frealtyp);
3972 iwork_size = (size_t)iwork_size_query * sizeof(fortran_int);
3973 }
3974
3975 mem_buff2 = (npy_uint8 *)malloc(work_size + rwork_size + iwork_size);
3976 if (!mem_buff2) {
3977 goto no_memory;
3978 }
3979
3980 work = mem_buff2;
3981 rwork = work + work_size;
3982 iwork = rwork + rwork_size;
3983
3984 params->WORK = (ftyp*)work;
3985 params->RWORK = (frealtyp*)rwork;
3986 params->IWORK = (fortran_int*)iwork;
3987 params->LWORK = work_count;
3988
3989 return 1;
3990
3991 no_memory:
3992 NPY_ALLOW_C_API_DEF
3993 NPY_ALLOW_C_API;
3994 PyErr_NoMemory();
3995 NPY_DISABLE_C_API;
3996
3997 error:
3998 TRACE_TXT("%s failed init\n", __FUNCTION__);
3999 free(mem_buff);
4000 free(mem_buff2);
4001 memset(params, 0, sizeof(*params));
4002
4003 return 0;
4004}
4005
4006template<typename ftyp>
4007static inline void
4008release_gelsd(GELSD_PARAMS_t<ftyp>* params)
4009{
4010 /* A and WORK contain allocated blocks */
4011 free(params->A);
4012 free(params->WORK);
4013 memset(params, 0, sizeof(*params));
4014}
4015
4016/** Compute the squared l2 norm of a contiguous vector */
4017template<typename typ>
4018static basetype_t<typ>
4019abs2(typ *p, npy_intp n, scalar_trait) {
4020 npy_intp i;
4021 basetype_t<typ> res = 0;
4022 for (i = 0; i < n; i++) {
4023 typ el = p[i];
4024 res += el*el;
4025 }
4026 return res;
4027}
4028template<typename typ>
4029static basetype_t<typ>
4030abs2(typ *p, npy_intp n, complex_trait) {
4031 npy_intp i;
4032 basetype_t<typ> res = 0;
4033 for (i = 0; i < n; i++) {
4034 typ el = p[i];
4035 res += RE(&el)*RE(&el) + IM(&el)*IM(&el);
4036 }
4037 return res;
4038}
4039
4040
4041template<typename typ>
4042static void
4043lstsq(char **args, npy_intp const *dimensions, npy_intp const *steps,
4044 void *NPY_UNUSED(func))
4045{
4046using ftyp = fortran_type_t<typ>;
4047using basetyp = basetype_t<typ>;
4048 GELSD_PARAMS_t<ftyp> params;
4049 int error_occurred = get_fp_invalid_and_clear();
4050 fortran_int n, m, nrhs;
4051 fortran_int excess;
4052
4053 INIT_OUTER_LOOP_7
4054
4055 m = (fortran_int)dimensions[0];
4056 n = (fortran_int)dimensions[1];
4057 nrhs = (fortran_int)dimensions[2];
4058 excess = m - n;
4059
4060 if (init_gelsd(¶ms, m, n, nrhs, dispatch_scalar<ftyp>{})) {
4061 LINEARIZE_DATA_t a_in, b_in, x_out, s_out, r_out;
4062
4063 init_linearize_data(&a_in, n, m, steps[1], steps[0]);
4064 init_linearize_data_ex(&b_in, nrhs, m, steps[3], steps[2], fortran_int_max(n, m));
4065 init_linearize_data_ex(&x_out, nrhs, n, steps[5], steps[4], fortran_int_max(n, m));
4066 init_linearize_data(&r_out, 1, nrhs, 1, steps[6]);
4067 init_linearize_data(&s_out, 1, fortran_int_min(n, m), 1, steps[7]);
4068
4069 BEGIN_OUTER_LOOP_7
4070 int not_ok;
4071 linearize_matrix((typ*)params.A, (typ*)args[0], &a_in);
4072 linearize_matrix((typ*)params.B, (typ*)args[1], &b_in);
4073 params.RCOND = (basetyp*)args[2];
4074 not_ok = call_gelsd(¶ms);
4075 if (!not_ok) {
4076 delinearize_matrix((typ*)args[3], (typ*)params.B, &x_out);
4077 *(npy_int*) args[5] = params.RANK;
4078 delinearize_matrix((basetyp*)args[6], (basetyp*)params.S, &s_out);
4079
4080 /* Note that linalg.lstsq discards this when excess == 0 */
4081 if (excess >= 0 && params.RANK == n) {
4082 /* Compute the residuals as the square sum of each column */
4083 int i;
4084 char *resid = args[4];
4085 ftyp *components = (ftyp *)params.B + n;
4086 for (i = 0; i < nrhs; i++) {
4087 ftyp *vector = components + i*m;
4088 /* Numpy and fortran floating types are the same size,
4089 * so this cast is safe */
4090 basetyp abs = abs2((typ *)vector, excess,
4091dispatch_scalar<typ>{});
4092 memcpy(
4093 resid + i*r_out.column_strides,
4094 &abs, sizeof(abs));
4095 }
4096 }
4097 else {
4098 /* Note that this is always discarded by linalg.lstsq */
4099 nan_matrix((basetyp*)args[4], &r_out);
4100 }
4101 } else {
4102 error_occurred = 1;
4103 nan_matrix((typ*)args[3], &x_out);
4104 nan_matrix((basetyp*)args[4], &r_out);
4105 *(npy_int*) args[5] = -1;
4106 nan_matrix((basetyp*)args[6], &s_out);
4107 }
4108 END_OUTER_LOOP
4109
4110 release_gelsd(¶ms);
4111 }
4112
4113 set_fp_invalid_or_clear(error_occurred);
4114}
4115
4116#pragma GCC diagnostic pop
4117
4118/* -------------------------------------------------------------------------- */
4119 /* gufunc registration */
4120
/*
 * Per-loop "data" pointers handed to PyUFunc_FromFuncAndDataAndSignature
 * by addUfuncs().  None of the loops uses extra data, so all sixteen
 * slots are NULL.
 */
static void *array_of_nulls[16] = {
    NULL, NULL, NULL, NULL,
    NULL, NULL, NULL, NULL,
    NULL, NULL, NULL, NULL,
    NULL, NULL, NULL, NULL
};
4142
/* Name of the loop-function array generated for gufunc NAME: NAME_funcs */
#define FUNC_ARRAY_NAME(NAME) NAME ## _funcs

/*
 * Array of two loops named FLOAT_<NAME> and DOUBLE_<NAME> (token pasted).
 * Not invoked in the registration list that follows these macros.
 */
#define GUFUNC_FUNC_ARRAY_REAL(NAME) \
    static PyUFuncGenericFunction \
    FUNC_ARRAY_NAME(NAME)[] = { \
        FLOAT_ ## NAME, \
        DOUBLE_ ## NAME \
    }

/*
 * Array of four loops named FLOAT_/DOUBLE_/CFLOAT_/CDOUBLE_<NAME>.
 * Not invoked in the registration list that follows these macros.
 */
#define GUFUNC_FUNC_ARRAY_REAL_COMPLEX(NAME) \
    static PyUFuncGenericFunction \
    FUNC_ARRAY_NAME(NAME)[] = { \
        FLOAT_ ## NAME, \
        DOUBLE_ ## NAME, \
        CFLOAT_ ## NAME, \
        CDOUBLE_ ## NAME \
    }
/*
 * Array of four instantiations of a two-parameter template NAME<T, B>:
 * element type T plus a second, real-valued type argument B.
 */
#define GUFUNC_FUNC_ARRAY_REAL_COMPLEX_(NAME) \
    static PyUFuncGenericFunction \
    FUNC_ARRAY_NAME(NAME)[] = { \
        NAME<npy_float, npy_float>, \
        NAME<npy_double, npy_double>, \
        NAME<npy_cfloat, npy_float>, \
        NAME<npy_cdouble, npy_double> \
    }
/*
 * Array of four instantiations of a one-parameter template NAME<T> for
 * float, double, complex float and complex double.
 */
#define GUFUNC_FUNC_ARRAY_REAL_COMPLEX__(NAME) \
    static PyUFuncGenericFunction \
    FUNC_ARRAY_NAME(NAME)[] = { \
        NAME<npy_float>, \
        NAME<npy_double>, \
        NAME<npy_cfloat>, \
        NAME<npy_cdouble> \
    }

/* There are problems with eig in complex single precision.
 * That kernel is disabled, so only three loops are listed here.
 */
#define GUFUNC_FUNC_ARRAY_EIG(NAME) \
    static PyUFuncGenericFunction \
    FUNC_ARRAY_NAME(NAME)[] = { \
        NAME<fortran_complex,fortran_real>, \
        NAME<fortran_doublecomplex,fortran_doublereal>, \
        NAME<fortran_doublecomplex,fortran_doublecomplex> \
    }

/* The single precision functions are not used at all,
 * due to input data being promoted to double precision
 * in Python, so they are not implemented here.
 */
#define GUFUNC_FUNC_ARRAY_QR(NAME) \
    static PyUFuncGenericFunction \
    FUNC_ARRAY_NAME(NAME)[] = { \
        DOUBLE_ ## NAME, \
        CDOUBLE_ ## NAME \
    }
/* Template flavour of the QR array: NAME<npy_double>, NAME<npy_cdouble> */
#define GUFUNC_FUNC_ARRAY_QR__(NAME) \
    static PyUFuncGenericFunction \
    FUNC_ARRAY_NAME(NAME)[] = { \
        NAME<npy_double>, \
        NAME<npy_cdouble> \
    }
4204
4205
/*
 * Instantiate one <name>_funcs loop array per kernel.  The number of
 * entries in each array must match the `ntypes` value used for it in
 * gufunc_descriptors (4 for the REAL_COMPLEX flavours, 2 for QR, 3 for EIG).
 */
GUFUNC_FUNC_ARRAY_REAL_COMPLEX_(slogdet);
GUFUNC_FUNC_ARRAY_REAL_COMPLEX_(det);
GUFUNC_FUNC_ARRAY_REAL_COMPLEX__(eighlo);
GUFUNC_FUNC_ARRAY_REAL_COMPLEX__(eighup);
GUFUNC_FUNC_ARRAY_REAL_COMPLEX__(eigvalshlo);
GUFUNC_FUNC_ARRAY_REAL_COMPLEX__(eigvalshup);
GUFUNC_FUNC_ARRAY_REAL_COMPLEX__(solve);
GUFUNC_FUNC_ARRAY_REAL_COMPLEX__(solve1);
GUFUNC_FUNC_ARRAY_REAL_COMPLEX__(inv);
GUFUNC_FUNC_ARRAY_REAL_COMPLEX__(cholesky_lo);
GUFUNC_FUNC_ARRAY_REAL_COMPLEX__(svd_N);
GUFUNC_FUNC_ARRAY_REAL_COMPLEX__(svd_S);
GUFUNC_FUNC_ARRAY_REAL_COMPLEX__(svd_A);
GUFUNC_FUNC_ARRAY_QR__(qr_r_raw);
GUFUNC_FUNC_ARRAY_QR__(qr_reduced);
GUFUNC_FUNC_ARRAY_QR__(qr_complete);
GUFUNC_FUNC_ARRAY_REAL_COMPLEX__(lstsq);
GUFUNC_FUNC_ARRAY_EIG(eig);
GUFUNC_FUNC_ARRAY_EIG(eigvals);
4225
/*
 * Type signature tables.  Each row describes one loop: the type numbers of
 * its inputs followed by those of its outputs.  Rows are in the same order
 * as the entries of the matching <name>_funcs array.
 */

/* one input, one output, all of the same type */
static char equal_2_types[] = {
    NPY_FLOAT, NPY_FLOAT,
    NPY_DOUBLE, NPY_DOUBLE,
    NPY_CFLOAT, NPY_CFLOAT,
    NPY_CDOUBLE, NPY_CDOUBLE
};

/* three operands, all of the same type */
static char equal_3_types[] = {
    NPY_FLOAT, NPY_FLOAT, NPY_FLOAT,
    NPY_DOUBLE, NPY_DOUBLE, NPY_DOUBLE,
    NPY_CFLOAT, NPY_CFLOAT, NPY_CFLOAT,
    NPY_CDOUBLE, NPY_CDOUBLE, NPY_CDOUBLE
};

/* second result is logdet, that will always be a REAL */
static char slogdet_types[] = {
    NPY_FLOAT, NPY_FLOAT, NPY_FLOAT,
    NPY_DOUBLE, NPY_DOUBLE, NPY_DOUBLE,
    NPY_CFLOAT, NPY_CFLOAT, NPY_FLOAT,
    NPY_CDOUBLE, NPY_CDOUBLE, NPY_DOUBLE
};

/* input, eigenvalues (real typed even for complex input), eigenvectors */
static char eigh_types[] = {
    NPY_FLOAT, NPY_FLOAT, NPY_FLOAT,
    NPY_DOUBLE, NPY_DOUBLE, NPY_DOUBLE,
    NPY_CFLOAT, NPY_FLOAT, NPY_CFLOAT,
    NPY_CDOUBLE, NPY_DOUBLE, NPY_CDOUBLE
};

/* input, eigenvalues (real typed even for complex input) */
static char eighvals_types[] = {
    NPY_FLOAT, NPY_FLOAT,
    NPY_DOUBLE, NPY_DOUBLE,
    NPY_CFLOAT, NPY_FLOAT,
    NPY_CDOUBLE, NPY_DOUBLE
};

/* three loops only; both outputs are complex typed */
static char eig_types[] = {
    NPY_FLOAT, NPY_CFLOAT, NPY_CFLOAT,
    NPY_DOUBLE, NPY_CDOUBLE, NPY_CDOUBLE,
    NPY_CDOUBLE, NPY_CDOUBLE, NPY_CDOUBLE
};

/* three loops only; the output is complex typed */
static char eigvals_types[] = {
    NPY_FLOAT, NPY_CFLOAT,
    NPY_DOUBLE, NPY_CDOUBLE,
    NPY_CDOUBLE, NPY_CDOUBLE
};

/* svd with a single output, which is real typed */
static char svd_1_1_types[] = {
    NPY_FLOAT, NPY_FLOAT,
    NPY_DOUBLE, NPY_DOUBLE,
    NPY_CFLOAT, NPY_FLOAT,
    NPY_CDOUBLE, NPY_DOUBLE
};

/* svd with three outputs; the second output is real typed */
static char svd_1_3_types[] = {
    NPY_FLOAT, NPY_FLOAT, NPY_FLOAT, NPY_FLOAT,
    NPY_DOUBLE, NPY_DOUBLE, NPY_DOUBLE, NPY_DOUBLE,
    NPY_CFLOAT, NPY_CFLOAT, NPY_FLOAT, NPY_CFLOAT,
    NPY_CDOUBLE, NPY_CDOUBLE, NPY_DOUBLE, NPY_CDOUBLE
};

/* A, tau */
static char qr_r_raw_types[] = {
    NPY_DOUBLE, NPY_DOUBLE,
    NPY_CDOUBLE, NPY_CDOUBLE,
};

/* A, tau, q */
static char qr_reduced_types[] = {
    NPY_DOUBLE, NPY_DOUBLE, NPY_DOUBLE,
    NPY_CDOUBLE, NPY_CDOUBLE, NPY_CDOUBLE,
};

/* A, tau, q */
static char qr_complete_types[] = {
    NPY_DOUBLE, NPY_DOUBLE, NPY_DOUBLE,
    NPY_CDOUBLE, NPY_CDOUBLE, NPY_CDOUBLE,
};

/* A, b, rcond, x, resid, rank, s, */
static char lstsq_types[] = {
    NPY_FLOAT, NPY_FLOAT, NPY_FLOAT, NPY_FLOAT, NPY_FLOAT, NPY_INT, NPY_FLOAT,
    NPY_DOUBLE, NPY_DOUBLE, NPY_DOUBLE, NPY_DOUBLE, NPY_DOUBLE, NPY_INT, NPY_DOUBLE,
    NPY_CFLOAT, NPY_CFLOAT, NPY_FLOAT, NPY_CFLOAT, NPY_FLOAT, NPY_INT, NPY_FLOAT,
    NPY_CDOUBLE, NPY_CDOUBLE, NPY_DOUBLE, NPY_CDOUBLE, NPY_DOUBLE, NPY_INT, NPY_DOUBLE,
};
4313
/*
 * Everything addUfuncs() needs to create and register one gufunc.
 */
typedef struct gufunc_descriptor_struct {
    const char *name;               /* ufunc name and key in the module dict */
    const char *signature;          /* gufunc core-dimension signature */
    const char *doc;                /* docstring passed to the ufunc constructor */
    int ntypes;                     /* number of loops in `funcs` / rows in `types` */
    int nin;                        /* number of input operands */
    int nout;                       /* number of output operands */
    PyUFuncGenericFunction *funcs;  /* loop functions, `ntypes` entries */
    char *types;                    /* type numbers, (nin + nout) per loop */
} GUFUNC_DESCRIPTOR_t;
4324
4325GUFUNC_DESCRIPTOR_t gufunc_descriptors [] = {
4326 {
4327 "slogdet",
4328 "(m,m)->(),()",
4329 "slogdet on the last two dimensions and broadcast on the rest. \n"\
4330 "Results in two arrays, one with sign and the other with log of the"\
4331 " determinants. \n"\
4332 " \"(m,m)->(),()\" \n",
4333 4, 1, 2,
4334 FUNC_ARRAY_NAME(slogdet),
4335 slogdet_types
4336 },
4337 {
4338 "det",
4339 "(m,m)->()",
4340 "det of the last two dimensions and broadcast on the rest. \n"\
4341 " \"(m,m)->()\" \n",
4342 4, 1, 1,
4343 FUNC_ARRAY_NAME(det),
4344 equal_2_types
4345 },
4346 {
4347 "eigh_lo",
4348 "(m,m)->(m),(m,m)",
4349 "eigh on the last two dimension and broadcast to the rest, using"\
4350 " lower triangle \n"\
4351 "Results in a vector of eigenvalues and a matrix with the"\
4352 "eigenvectors. \n"\
4353 " \"(m,m)->(m),(m,m)\" \n",
4354 4, 1, 2,
4355 FUNC_ARRAY_NAME(eighlo),
4356 eigh_types
4357 },
4358 {
4359 "eigh_up",
4360 "(m,m)->(m),(m,m)",
4361 "eigh on the last two dimension and broadcast to the rest, using"\
4362 " upper triangle. \n"\
4363 "Results in a vector of eigenvalues and a matrix with the"\
4364 " eigenvectors. \n"\
4365 " \"(m,m)->(m),(m,m)\" \n",
4366 4, 1, 2,
4367 FUNC_ARRAY_NAME(eighup),
4368 eigh_types
4369 },
4370 {
4371 "eigvalsh_lo",
4372 "(m,m)->(m)",
4373 "eigh on the last two dimension and broadcast to the rest, using"\
4374 " lower triangle. \n"\
4375 "Results in a vector of eigenvalues and a matrix with the"\
4376 "eigenvectors. \n"\
4377 " \"(m,m)->(m)\" \n",
4378 4, 1, 1,
4379 FUNC_ARRAY_NAME(eigvalshlo),
4380 eighvals_types
4381 },
4382 {
4383 "eigvalsh_up",
4384 "(m,m)->(m)",
4385 "eigvalsh on the last two dimension and broadcast to the rest,"\
4386 " using upper triangle. \n"\
4387 "Results in a vector of eigenvalues and a matrix with the"\
4388 "eigenvectors.\n"\
4389 " \"(m,m)->(m)\" \n",
4390 4, 1, 1,
4391 FUNC_ARRAY_NAME(eigvalshup),
4392 eighvals_types
4393 },
4394 {
4395 "solve",
4396 "(m,m),(m,n)->(m,n)",
4397 "solve the system a x = b, on the last two dimensions, broadcast"\
4398 " to the rest. \n"\
4399 "Results in a matrices with the solutions. \n"\
4400 " \"(m,m),(m,n)->(m,n)\" \n",
4401 4, 2, 1,
4402 FUNC_ARRAY_NAME(solve),
4403 equal_3_types
4404 },
4405 {
4406 "solve1",
4407 "(m,m),(m)->(m)",
4408 "solve the system a x = b, for b being a vector, broadcast in"\
4409 " the outer dimensions. \n"\
4410 "Results in vectors with the solutions. \n"\
4411 " \"(m,m),(m)->(m)\" \n",
4412 4, 2, 1,
4413 FUNC_ARRAY_NAME(solve1),
4414 equal_3_types
4415 },
4416 {
4417 "inv",
4418 "(m, m)->(m, m)",
4419 "compute the inverse of the last two dimensions and broadcast"\
4420 " to the rest. \n"\
4421 "Results in the inverse matrices. \n"\
4422 " \"(m,m)->(m,m)\" \n",
4423 4, 1, 1,
4424 FUNC_ARRAY_NAME(inv),
4425 equal_2_types
4426 },
4427 {
4428 "cholesky_lo",
4429 "(m,m)->(m,m)",
4430 "cholesky decomposition of hermitian positive-definite matrices. \n"\
4431 "Broadcast to all outer dimensions. \n"\
4432 " \"(m,m)->(m,m)\" \n",
4433 4, 1, 1,
4434 FUNC_ARRAY_NAME(cholesky_lo),
4435 equal_2_types
4436 },
4437 {
4438 "svd_m",
4439 "(m,n)->(m)",
4440 "svd when n>=m. ",
4441 4, 1, 1,
4442 FUNC_ARRAY_NAME(svd_N),
4443 svd_1_1_types
4444 },
4445 {
4446 "svd_n",
4447 "(m,n)->(n)",
4448 "svd when n<=m",
4449 4, 1, 1,
4450 FUNC_ARRAY_NAME(svd_N),
4451 svd_1_1_types
4452 },
4453 {
4454 "svd_m_s",
4455 "(m,n)->(m,m),(m),(m,n)",
4456 "svd when m<=n",
4457 4, 1, 3,
4458 FUNC_ARRAY_NAME(svd_S),
4459 svd_1_3_types
4460 },
4461 {
4462 "svd_n_s",
4463 "(m,n)->(m,n),(n),(n,n)",
4464 "svd when m>=n",
4465 4, 1, 3,
4466 FUNC_ARRAY_NAME(svd_S),
4467 svd_1_3_types
4468 },
4469 {
4470 "svd_m_f",
4471 "(m,n)->(m,m),(m),(n,n)",
4472 "svd when m<=n",
4473 4, 1, 3,
4474 FUNC_ARRAY_NAME(svd_A),
4475 svd_1_3_types
4476 },
4477 {
4478 "svd_n_f",
4479 "(m,n)->(m,m),(n),(n,n)",
4480 "svd when m>=n",
4481 4, 1, 3,
4482 FUNC_ARRAY_NAME(svd_A),
4483 svd_1_3_types
4484 },
4485 {
4486 "eig",
4487 "(m,m)->(m),(m,m)",
4488 "eig on the last two dimension and broadcast to the rest. \n"\
4489 "Results in a vector with the eigenvalues and a matrix with the"\
4490 " eigenvectors. \n"\
4491 " \"(m,m)->(m),(m,m)\" \n",
4492 3, 1, 2,
4493 FUNC_ARRAY_NAME(eig),
4494 eig_types
4495 },
4496 {
4497 "eigvals",
4498 "(m,m)->(m)",
4499 "eigvals on the last two dimension and broadcast to the rest. \n"\
4500 "Results in a vector of eigenvalues. \n",
4501 3, 1, 1,
4502 FUNC_ARRAY_NAME(eigvals),
4503 eigvals_types
4504 },
4505 {
4506 "qr_r_raw_m",
4507 "(m,n)->(m)",
4508 "Compute TAU vector for the last two dimensions \n"\
4509 "and broadcast to the rest. For m <= n. \n",
4510 2, 1, 1,
4511 FUNC_ARRAY_NAME(qr_r_raw),
4512 qr_r_raw_types
4513 },
4514 {
4515 "qr_r_raw_n",
4516 "(m,n)->(n)",
4517 "Compute TAU vector for the last two dimensions \n"\
4518 "and broadcast to the rest. For m > n. \n",
4519 2, 1, 1,
4520 FUNC_ARRAY_NAME(qr_r_raw),
4521 qr_r_raw_types
4522 },
4523 {
4524 "qr_reduced",
4525 "(m,n),(k)->(m,k)",
4526 "Compute Q matrix for the last two dimensions \n"\
4527 "and broadcast to the rest. \n",
4528 2, 2, 1,
4529 FUNC_ARRAY_NAME(qr_reduced),
4530 qr_reduced_types
4531 },
4532 {
4533 "qr_complete",
4534 "(m,n),(n)->(m,m)",
4535 "Compute Q matrix for the last two dimensions \n"\
4536 "and broadcast to the rest. For m > n. \n",
4537 2, 2, 1,
4538 FUNC_ARRAY_NAME(qr_complete),
4539 qr_complete_types
4540 },
4541 {
4542 "lstsq_m",
4543 "(m,n),(m,nrhs),()->(n,nrhs),(nrhs),(),(m)",
4544 "least squares on the last two dimensions and broadcast to the rest. \n"\
4545 "For m <= n. \n",
4546 4, 3, 4,
4547 FUNC_ARRAY_NAME(lstsq),
4548 lstsq_types
4549 },
4550 {
4551 "lstsq_n",
4552 "(m,n),(m,nrhs),()->(n,nrhs),(nrhs),(),(n)",
4553 "least squares on the last two dimensions and broadcast to the rest. \n"\
4554 "For m >= n, meaning that residuals are produced. \n",
4555 4, 3, 4,
4556 FUNC_ARRAY_NAME(lstsq),
4557 lstsq_types
4558 }
4559};
4560
4561static int
4562addUfuncs(PyObject *dictionary) {
4563 PyObject *f;
4564 int i;
4565 const int gufunc_count = sizeof(gufunc_descriptors)/
4566 sizeof(gufunc_descriptors[0]);
4567 for (i = 0; i < gufunc_count; i++) {
4568 GUFUNC_DESCRIPTOR_t* d = &gufunc_descriptors[i];
4569 f = PyUFunc_FromFuncAndDataAndSignature(d->funcs,
4570 array_of_nulls,
4571 d->types,
4572 d->ntypes,
4573 d->nin,
4574 d->nout,
4575 PyUFunc_None,
4576 d->name,
4577 d->doc,
4578 0,
4579 d->signature);
4580 if (f == NULL) {
4581 return -1;
4582 }
4583#if 0
4584 dump_ufunc_object((PyUFuncObject*) f);
4585#endif
4586 int ret = PyDict_SetItemString(dictionary, d->name, f);
4587 Py_DECREF(f);
4588 if (ret < 0) {
4589 return -1;
4590 }
4591 }
4592 return 0;
4593}
4594
4595
4596
4597/* -------------------------------------------------------------------------- */
4598 /* Module initialization stuff */
4599
/*
 * Method table of the module.  It contains only the terminating sentinel:
 * everything the module exports is added to its dictionary by addUfuncs().
 */
static PyMethodDef UMath_LinAlgMethods[] = {
    {NULL, NULL, 0, NULL} /* Sentinel */
};
4603
/*
 * Module definition passed to PyModule_Create().  The initialisers are
 * positional and follow the PyModuleDef field order.
 */
static struct PyModuleDef moduledef = {
    PyModuleDef_HEAD_INIT,      /* m_base */
    UMATH_LINALG_MODULE_NAME,   /* m_name */
    NULL,                       /* m_doc */
    -1,                         /* m_size */
    UMath_LinAlgMethods,        /* m_methods */
    NULL,                       /* m_slots */
    NULL,                       /* m_traverse */
    NULL,                       /* m_clear */
    NULL                        /* m_free */
};
4615
4616PyMODINIT_FUNC PyInit__umath_linalg(void)
4617{
4618 PyObject *m;
4619 PyObject *d;
4620 PyObject *version;
4621
4622 m = PyModule_Create(&moduledef);
4623 if (m == NULL) {
4624 return NULL;
4625 }
4626
4627 import_array();
4628 import_ufunc();
4629
4630 d = PyModule_GetDict(m);
4631 if (d == NULL) {
4632 return NULL;
4633 }
4634
4635 version = PyUnicode_FromString(umath_linalg_version_string);
4636 if (version == NULL) {
4637 return NULL;
4638 }
4639 int ret = PyDict_SetItemString(d, "__version__", version);
4640 Py_DECREF(version);
4641 if (ret < 0) {
4642 return NULL;
4643 }
4644
4645 /* Load the ufunc operators into the module's namespace */
4646 if (addUfuncs(d) < 0) {
4647 return NULL;
4648 }
4649
4650#ifdef HAVE_BLAS_ILP64
4651 PyDict_SetItemString(d, "_ilp64", Py_True);
4652#else
4653 PyDict_SetItemString(d, "_ilp64", Py_False);
4654#endif
4655
4656 return m;
4657}