Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use more class-based FFTs in pycbc_live #4096

Merged
merged 18 commits into from
Aug 12, 2022
Merged
Prev Previous commit
Next Next commit
Significant rework of cached_FFT
  • Loading branch information
spxiwh committed Aug 10, 2022
commit 4a9a4e790a9333a56be758c74668971d1dfa8c9e
2 changes: 1 addition & 1 deletion pycbc/fft/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,5 +16,5 @@

from .parser_support import insert_fft_option_group, verify_fft_options, from_cli
from .func_api import fft, ifft
from .class_api import FFT, IFFT, create_memory_and_engine_for_class_based_fft
from .class_api import FFT, IFFT
from .backend_support import get_backend_names
59 changes: 0 additions & 59 deletions pycbc/fft/class_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,62 +88,3 @@ def __new__(cls, *args, **kwargs):
real_cls = _ifft_factory(*args, **kwargs)
return real_cls(*args, **kwargs)


def create_memory_and_engine_for_class_based_fft(
    npoints_time,
    dtype,
    delta_t=1,
    ifft=False
):
    """ Allocate the input/output vectors and the engine for a class-based
    FFT or IFFT.

    Only the real-to-complex FFT and the complex-to-real IFFT are handled
    here.

    Parameters
    ----------
    npoints_time : int
        Length of the real (time-domain) vector. This is the input of an FFT
        and the output of an IFFT.
    dtype : np.dtype
        dtype of the real (time-domain) vector; expected to be np.float32 or
        np.float64.
    delta_t : float (default 1)
        Sample spacing of the real vector. The default of 1 is intended for
        callers that do not need physical units on the returned series.
    ifft : boolean (default False)
        If True build an IFFT engine, otherwise build an FFT engine.

    Returns
    -------
    invec : TimeSeries or FrequencySeries
        Memory the engine reads from (TimeSeries for an FFT,
        FrequencySeries for an IFFT).
    outvec : FrequencySeries or TimeSeries
        Memory the engine writes to (the other one of the pair).
    fft_class : FFT or IFFT
        Engine constructed on (invec, outvec).
    """
    from pycbc.types import FrequencySeries, TimeSeries, zeros
    from pycbc.types import complex_same_precision_as

    # A real vector of N samples pairs with N // 2 + 1 complex frequency bins
    freq_len = npoints_time // 2 + 1
    freq_spacing = 1.0 / (npoints_time * delta_t)

    time_mem = TimeSeries(
        zeros(npoints_time, dtype=dtype),
        delta_t=delta_t,
        copy=False
    )
    freq_mem = FrequencySeries(
        zeros(freq_len, dtype=complex_same_precision_as(time_mem)),
        delta_f=freq_spacing,
        copy=False
    )

    if ifft:
        # Frequency domain in, time domain out
        return freq_mem, time_mem, IFFT(freq_mem, time_mem)

    # Time domain in, frequency domain out
    return time_mem, freq_mem, FFT(time_mem, freq_mem)


51 changes: 21 additions & 30 deletions pycbc/filter/resample.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
import scipy.signal
from pycbc.types import TimeSeries, Array, zeros, FrequencySeries, real_same_precision_as
from pycbc.types import complex_same_precision_as
from pycbc.fft import ifft, fft, create_memory_and_engine_for_class_based_fft
from pycbc.fft import ifft, fft

_resample_func = {numpy.dtype('float32'): lal.ResampleREAL4TimeSeries,
numpy.dtype('float64'): lal.ResampleREAL8TimeSeries}
Expand Down Expand Up @@ -72,6 +72,9 @@ def lfilter(coefficients, timeseries):
epoch=timeseries.start_time,
delta_t=timeseries.delta_t)
elif (len(timeseries) < fillen * 10) or (len(timeseries) < 2**18):
from pycbc.strain.strain import create_memory_and_engine_for_class_based_fft
from pycbc.strain.strain import execute_cached_fft, execute_cached_ifft

cseries = (Array(coefficients[::-1] * 1)).astype(timeseries.dtype)
cseries.resize(len(timeseries))
cseries.roll(len(timeseries) - fillen + 1)
Expand All @@ -91,39 +94,27 @@ def lfilter(coefficients, timeseries):

else:
npoints = len(cseries)
if (npoints, ftype) not in fft_cache:
fft1outs = create_memory_and_engine_for_class_based_fft(
npoints,
timeseries.dtype,
ifft=False
)

fft2outs = create_memory_and_engine_for_class_based_fft(
npoints,
timeseries.dtype,
ifft=False
)

ifftouts = create_memory_and_engine_for_class_based_fft(
npoints,
timeseries.dtype,
ifft=True
)

fft_cache[(npoints, ftype)] = [fft1outs, fft2outs, ifftouts]

fft1outs, fft2outs, ifftouts = fft_cache[(npoints, ftype)]
vec, cfreq, fft_class = fft1outs
vec._data[:] = cseries._data[:]
fft_class.execute()

vec, tfreq, fft_class = fft2outs
vec._data[:] = timeseries._data[:]
fft_class.execute()
# NOTE: This function is cached!
ifftouts = create_memory_and_engine_for_class_based_fft(
npoints,
timeseries.dtype,
ifft=True,
uid=486876761
)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm assuming that the number you select for the uid here is just so that it's different from anything else? If there's a reason it has this value (also below) then please add a code comment explaining what determines the value.

If it's just to be unique from other calls of the same function, would it be clearer to just define some constants in this file, like LFILTER_UNIQUE_ID = 1, or whatever?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks Josh. This is just to give different outputs if called with the same inputs from different functions. The number just has to be unique. I made this a constant, with some inline comments, as suggested.


# FFT contents of cseries into cfreq
cfreq = execute_cached_fft(cseries, uid=46464651,
normalize_by_rate=False)

# FFT contents of timeseries into tfreq
tfreq = execute_cached_fft(timeseries, uid=91236752,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same comment as above about magic number UID.

normalize_by_rate=False)

cout, out, fft_class = ifftouts

# Correlate cfreq and tfreq
correlate(cfreq, tfreq, cout)
# IFFT correlation output into out
fft_class.execute()

return TimeSeries(out.numpy() / len(out), epoch=timeseries.start_time,
Expand Down
180 changes: 153 additions & 27 deletions pycbc/strain/strain.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
"""
import copy
import logging, numpy
import functools
import pycbc.types
from pycbc.types import TimeSeries, zeros
from pycbc.types import Array, FrequencySeries
Expand All @@ -31,7 +32,7 @@
from pycbc.filter.zpk import filter_zpk
from pycbc.waveform.spa_tmplt import spa_distance
import pycbc.psd
from pycbc.fft import create_memory_and_engine_for_class_based_fft
from pycbc.fft import FFT, IFFT
import pycbc.events
import pycbc.frame
import pycbc.filter
Expand Down Expand Up @@ -1245,6 +1246,143 @@ def verify_segment_options_multi_ifo(cls, opt, parser, ifos):
required_opts_multi_ifo(opt, parser, ifo, cls.required_opts_list)


@functools.lru_cache(maxsize=500)
def create_memory_and_engine_for_class_based_fft(
    npoints_time,
    dtype,
    delta_t=1,
    ifft=False,
    uid=0
):
    """ Allocate the input/output vectors and the engine for a class-based
    FFT or IFFT.

    Only the real-to-complex FFT and the complex-to-real IFFT are handled
    here. The function is wrapped in an LRU cache, so calling it again with
    the same arguments returns the very same vectors and engine rather than
    allocating new ones.

    Parameters
    ----------
    npoints_time : int
        Length of the real (time-domain) vector. This is the input of an FFT
        and the output of an IFFT.
    dtype : np.dtype
        dtype of the real (time-domain) vector; expected to be np.float32 or
        np.float64.
    delta_t : float (default: 1)
        Sample spacing of the real vector. The default of 1 is intended for
        callers that do not need physical units on the returned series.
    ifft : boolean (default: False)
        If True build an IFFT engine, otherwise build an FFT engine.
    uid : int (default: 0)
        Unique identifier. It is not used in the computation; it only takes
        part in the cache key so that different callers asking for the same
        size and dtype can be given separate memory.

    Returns
    -------
    invec : TimeSeries or FrequencySeries
        Memory the engine reads from (TimeSeries for an FFT,
        FrequencySeries for an IFFT).
    outvec : FrequencySeries or TimeSeries
        Memory the engine writes to (the other one of the pair).
    fft_class : FFT or IFFT
        Engine constructed on (invec, outvec).
    """
    from pycbc.types import FrequencySeries, TimeSeries, zeros
    from pycbc.types import complex_same_precision_as

    # A real vector of N samples pairs with N // 2 + 1 complex frequency bins
    freq_len = npoints_time // 2 + 1
    freq_spacing = 1.0 / (npoints_time * delta_t)

    time_mem = TimeSeries(
        zeros(npoints_time, dtype=dtype),
        delta_t=delta_t,
        copy=False
    )
    freq_mem = FrequencySeries(
        zeros(freq_len, dtype=complex_same_precision_as(time_mem)),
        delta_f=freq_spacing,
        copy=False
    )

    if ifft:
        # Frequency domain in, time domain out
        return freq_mem, time_mem, IFFT(freq_mem, time_mem)

    # Time domain in, frequency domain out
    return time_mem, freq_mem, FFT(time_mem, freq_mem)


def execute_cached_fft(invec_data, normalize_by_rate=True, ifft=False,
                       uid=0):
    """ Executes a cached FFT

    The contents of invec_data are copied into the input memory obtained from
    create_memory_and_engine_for_class_based_fft, the matching engine is
    executed, and the output memory is returned.

    Parameters
    -----------
    invec_data : Array
        Array which will be used as input when fft_class is executed.
    normalize_by_rate : boolean (optional, default: True)
        If True, then normalize by delta_t (for an FFT) or delta_f (for an
        IFFT).
    ifft : boolean (optional, default: False)
        If True treat this as an IFFT: the input is frequency-domain data and
        the normalization (if requested) uses delta_f rather than delta_t.
    uid : int (default: 0)
        Provide a unique identifier. This is used to provide a separate set
        of memory in the cache, for instance if calling this from different
        codes.

    Returns
    --------
    outvec : FrequencySeries or TimeSeries
        Output memory of the engine (FrequencySeries for an FFT, TimeSeries
        for an IFFT). This is the cached vector itself, not a copy: a later
        call that resolves to the same cache entry (same length, dtype,
        delta_t, ifft and uid) overwrites it in place.

    Raises
    -------
    AttributeError
        If normalize_by_rate is True and invec_data has no delta_t.
    """
    from pycbc.types import real_same_precision_as

    # The memory is sized by the length of the real (time-domain) vector.
    # For an IFFT the input holds N // 2 + 1 frequency bins, so the
    # time-domain length is recovered assuming N is even.
    if ifft:
        npoints_time = (len(invec_data) - 1) * 2
    else:
        npoints_time = len(invec_data)

    try:
        delta_t = invec_data.delta_t
    except AttributeError:
        if normalize_by_rate:
            # The rate is required for the normalization, so give up
            raise
        # The rate is not needed: fall back to the default spacing
        delta_t = 1

    dtype = real_same_precision_as(invec_data)

    invec, outvec, fft_class = create_memory_and_engine_for_class_based_fft(
        npoints_time,
        dtype,
        delta_t=delta_t,
        ifft=ifft,
        uid=uid
    )

    # invec_data has already been dereferenced above, so it cannot be None
    # here and the copy is unconditional.
    invec._data[:] = invec_data._data[:]
    fft_class.execute()

    if normalize_by_rate:
        if ifft:
            outvec._data *= invec._delta_f
        else:
            outvec._data *= invec._delta_t
    return outvec


def execute_cached_ifft(*args, **kwargs):
    """ Executes a cached IFFT

    Convenience wrapper which forwards all arguments to execute_cached_fft
    with ifft=True. Do not pass ifft yourself: supplying it (by keyword or
    as the third positional argument) raises a TypeError because the
    argument would be given twice.

    Parameters
    -----------
    invec_data : Array
        Array which will be used as input when fft_class is executed.
    normalize_by_rate : boolean (optional, default: True)
        If True, then normalize the output by delta_f.
    uid : int (default: 0)
        Provide a unique identifier. This is used to provide a separate set
        of memory in the cache, for instance if calling this from different
        codes.

    Returns
    --------
    outvec
        Whatever execute_cached_fft returns for ifft=True.
    """
    return execute_cached_fft(*args, ifft=True, **kwargs)


class StrainBuffer(pycbc.frame.DataBuffer):
def __init__(self, frame_src, channel_name, start_time,
max_buffer=512,
Expand Down Expand Up @@ -1443,8 +1581,6 @@ def __init__(self, frame_src, channel_name, start_time,
self.add_hard_count()
self.taper_immediate_strain = True

# Caches for FFTs to use class based API
self.fft_cache = {}

@property
def start_time(self):
Expand Down Expand Up @@ -1515,28 +1651,27 @@ def create_memory_for_overwhitened_data(self, npoints_time):
npoints_time,
self.strain.dtype,
delta_t=self.strain.delta_t,
ifft=False
ifft=False,
uid=87123876
)

whitened_data_ifft_outs = create_memory_and_engine_for_class_based_fft(
npoints_time,
self.strain.dtype,
delta_t=self.strain.delta_t,
ifft=True
ifft=True,
uid=264654684
)

trimmed_data_fft_outs = create_memory_and_engine_for_class_based_fft(
npoints_time - (2 * self.reduced_pad),
self.strain.dtype,
delta_t=self.strain.delta_t,
ifft=False
ifft=False,
uid=712394716
)

self.fft_cache[npoints_time] = (
data_fft_outs,
whitened_data_ifft_outs,
trimmed_data_fft_outs
)
return (data_fft_outs, whitened_data_ifft_outs, trimmed_data_fft_outs)

def overwhitened_data(self, delta_f):
""" Return overwhitened data
Expand All @@ -1557,14 +1692,9 @@ def overwhitened_data(self, delta_f):
e = len(self.strain)
s = int(e - buffer_length * self.sample_rate - self.reduced_pad * 2)
npoints_time = e - s
if npoints_time not in self.fft_cache:
self.create_memory_for_overwhitened_data(npoints_time)
fft_memory = self.fft_cache[npoints_time]

vec, fseries, fft_class = self.fft_cache[npoints_time][0]
vec._data[:] = self.strain[s:e]
fft_class.execute() # vec -> FFT -> fseries
fseries._data *= vec._delta_t

# FFT the contents of self.strain[s:e] into fseries
fseries = execute_cached_fft(self.strain[s:e], uid=85437862)
fseries._epoch = self.strain._epoch + s*self.strain.delta_t

# we haven't calculated a resample psd for this delta_f
Expand All @@ -1589,21 +1719,17 @@ def overwhitened_data(self, delta_f):
# trim ends of strain
if self.reduced_pad != 0:
npoints_time = e - s
vectilde, overwhite, fft_class = self.fft_cache[npoints_time][1]
vectilde._data[:] = fseries._data[:]
fft_class.execute() # vectilde -> IFFT -> overwhite
overwhite._data *= vectilde._delta_f
# IFFT the contents of fseries into overwhite
overwhite = execute_cached_ifft(fseries, uid=98961342)

overwhite2 = overwhite[self.reduced_pad:len(overwhite)-self.reduced_pad]
taper_window = self.trim_padding / 2.0 / overwhite.sample_rate
gate_params = [(overwhite2.start_time, 0., taper_window),
(overwhite2.end_time, 0., taper_window)]
gate_data(overwhite2, gate_params)

vec, fseries_trimmed, fft_class = self.fft_cache[npoints_time][2]
vec._data[:] = overwhite2._data[:]
fft_class.execute() # vec -> FFT -> fseries_trimmed
fseries_trimmed._data *= vec._delta_t
# FFT the contents of overwhite2 into fseries_trimmed
fseries_trimmed = execute_cached_fft(overwhite2, uid=91237641)

fseries_trimmed.start_time = fseries.start_time + self.reduced_pad * self.strain.delta_t
else:
Expand Down