37
37
38
38
"""
39
39
40
+ import math
41
+
40
42
import dpctl .tensor as dpt
41
43
import dpctl .tensor ._tensor_elementwise_impl as ti
42
44
import dpctl .utils as dpu
@@ -481,24 +483,66 @@ def _get_padding(a_size, v_size, mode):
481
483
r_pad = v_size - l_pad - 1
482
484
elif mode == "full" :
483
485
l_pad , r_pad = v_size - 1 , v_size - 1
484
- else :
486
+ else : # pragma: no cover
485
487
raise ValueError (
486
488
f"Unknown mode: { mode } . Only 'valid', 'same', 'full' are supported."
487
489
)
488
490
489
491
return l_pad , r_pad
490
492
491
493
492
- def _run_native_sliding_dot_product1d (a , v , l_pad , r_pad ):
494
+ def _choose_conv_method (a , v , rdtype ):
495
+ assert a .size >= v .size
496
+ if rdtype == dpnp .bool :
497
+ # to avoid accuracy issues
498
+ return "direct"
499
+
500
+ if v .size < 10 ** 4 or a .size < 10 ** 4 :
501
+ # direct method is faster for small arrays
502
+ return "direct"
503
+
504
+ if dpnp .issubdtype (rdtype , dpnp .integer ):
505
+ max_a = int (dpnp .max (dpnp .abs (a )))
506
+ sum_v = int (dpnp .sum (dpnp .abs (v )))
507
+ max_value = int (max_a * sum_v )
508
+
509
+ default_float = dpnp .default_float_type (a .sycl_device )
510
+ if max_value > 2 ** numpy .finfo (default_float ).nmant - 1 :
511
+ # can't represent the result in the default float type
512
+ return "direct" # pragma: no covers
513
+
514
+ if dpnp .issubdtype (rdtype , dpnp .number ):
515
+ return "fft"
516
+
517
+ raise ValueError (f"Unsupported dtype: { rdtype } " ) # pragma: no cover
518
+
519
+
520
+ def _run_native_sliding_dot_product1d (a , v , l_pad , r_pad , rdtype ):
493
521
queue = a .sycl_queue
522
+ device = a .sycl_device
523
+
524
+ supported_types = statistics_ext .sliding_dot_product1d_dtypes ()
525
+ supported_dtype = to_supported_dtypes (rdtype , supported_types , device )
526
+
527
+ if supported_dtype is None : # pragma: no cover
528
+ raise ValueError (
529
+ f"function does not support input types "
530
+ f"({ a .dtype .name } , { v .dtype .name } ), "
531
+ "and the inputs could not be coerced to any "
532
+ f"supported types. List of supported types: "
533
+ f"{ [st .name for st in supported_types ]} "
534
+ )
535
+
536
+ a_casted = dpnp .asarray (a , dtype = supported_dtype , order = "C" )
537
+ v_casted = dpnp .asarray (v , dtype = supported_dtype , order = "C" )
494
538
495
- usm_type = dpu .get_coerced_usm_type ([a .usm_type , v .usm_type ])
496
- out_size = l_pad + r_pad + a .size - v .size + 1
539
+ usm_type = dpu .get_coerced_usm_type ([a_casted .usm_type , v_casted .usm_type ])
540
+ out_size = l_pad + r_pad + a_casted .size - v_casted .size + 1
497
541
# out type is the same as input type
498
- out = dpnp .empty_like (a , shape = out_size , usm_type = usm_type )
542
+ out = dpnp .empty_like (a_casted , shape = out_size , usm_type = usm_type )
499
543
500
- a_usm = dpnp .get_usm_ndarray (a )
501
- v_usm = dpnp .get_usm_ndarray (v )
544
+ a_usm = dpnp .get_usm_ndarray (a_casted )
545
+ v_usm = dpnp .get_usm_ndarray (v_casted )
502
546
out_usm = dpnp .get_usm_ndarray (out )
503
547
504
548
_manager = dpu .SequentialOrderManager [queue ]
@@ -516,7 +560,30 @@ def _run_native_sliding_dot_product1d(a, v, l_pad, r_pad):
516
560
return out
517
561
518
562
519
- def correlate (a , v , mode = "valid" ):
563
+ def _convolve_fft (a , v , l_pad , r_pad , rtype ):
564
+ assert a .size >= v .size
565
+ assert l_pad < v .size
566
+
567
+ # +1 is needed to avoid circular convolution
568
+ padded_size = a .size + r_pad + 1
569
+ fft_size = 2 ** int (math .ceil (math .log2 (padded_size )))
570
+
571
+ af = dpnp .fft .fft (a , fft_size ) # pylint: disable=no-member
572
+ vf = dpnp .fft .fft (v , fft_size ) # pylint: disable=no-member
573
+
574
+ r = dpnp .fft .ifft (af * vf ) # pylint: disable=no-member
575
+ if dpnp .issubdtype (rtype , dpnp .floating ):
576
+ r = r .real
577
+ elif dpnp .issubdtype (rtype , dpnp .integer ) or rtype == dpnp .bool :
578
+ r = r .real .round ()
579
+
580
+ start = v .size - 1 - l_pad
581
+ end = padded_size - 1
582
+
583
+ return r [start :end ]
584
+
585
+
586
+ def correlate (a , v , mode = "valid" , method = "auto" ):
520
587
r"""
521
588
Cross-correlation of two 1-dimensional sequences.
522
589
@@ -541,6 +608,20 @@ def correlate(a, v, mode="valid"):
541
608
is ``"valid"``, unlike :obj:`dpnp.convolve`, which uses ``"full"``.
542
609
543
610
Default: ``"valid"``.
611
+ method : {"auto", "direct", "fft"}, optional
612
+ Specifies which method to use to calculate the correlation:
613
+
614
+ - `"direct"` : The correlation is determined directly from sums.
615
+ - `"fft"` : The Fourier Transform is used to perform the calculations.
616
+ This method is faster for long sequences but can have accuracy issues.
617
+ - `"auto"` : Automatically chooses direct or Fourier method based on
618
+ an estimate of which is faster.
619
+
620
+ Note: Use of the FFT convolution on input containing NAN or INF
621
+ will lead to the entire output being NAN or INF.
622
+ Use method='direct' when your input contains NAN or INF values.
623
+
624
+ Default: ``"auto"``.
544
625
545
626
Returns
546
627
-------
@@ -608,20 +689,14 @@ def correlate(a, v, mode="valid"):
608
689
f"Received shapes: a.shape={ a .shape } , v.shape={ v .shape } "
609
690
)
610
691
611
- supported_types = statistics_ext .sliding_dot_product1d_dtypes ()
692
+ supported_methods = ["auto" , "direct" , "fft" ]
693
+ if method not in supported_methods :
694
+ raise ValueError (
695
+ f"Unknown method: { method } . Supported methods: { supported_methods } "
696
+ )
612
697
613
698
device = a .sycl_device
614
699
rdtype = result_type_for_device ([a .dtype , v .dtype ], device )
615
- supported_dtype = to_supported_dtypes (rdtype , supported_types , device )
616
-
617
- if supported_dtype is None : # pragma: no cover
618
- raise ValueError (
619
- f"function does not support input types "
620
- f"({ a .dtype .name } , { v .dtype .name } ), "
621
- "and the inputs could not be coerced to any "
622
- f"supported types. List of supported types: "
623
- f"{ [st .name for st in supported_types ]} "
624
- )
625
700
626
701
if dpnp .issubdtype (v .dtype , dpnp .complexfloating ):
627
702
v = dpnp .conj (v )
@@ -633,10 +708,15 @@ def correlate(a, v, mode="valid"):
633
708
634
709
l_pad , r_pad = _get_padding (a .size , v .size , mode )
635
710
636
- a_casted = dpnp . asarray ( a , dtype = supported_dtype , order = "C" )
637
- v_casted = dpnp . asarray ( v , dtype = supported_dtype , order = "C" )
711
+ if method == "auto" :
712
+ method = _choose_conv_method ( a , v , rdtype )
638
713
639
- r = _run_native_sliding_dot_product1d (a_casted , v_casted , l_pad , r_pad )
714
+ if method == "direct" :
715
+ r = _run_native_sliding_dot_product1d (a , v , l_pad , r_pad , rdtype )
716
+ elif method == "fft" :
717
+ r = _convolve_fft (a , v [::- 1 ], l_pad , r_pad , rdtype )
718
+ else : # pragma: no cover
719
+ raise ValueError (f"Unknown method: { method } " )
640
720
641
721
if revert :
642
722
r = r [::- 1 ]
0 commit comments