Unify QuantileOutput and DistributionOutput (#3093)
shchur committed Jan 10, 2024
1 parent 7835302 commit c99dafa
Showing 47 changed files with 639 additions and 550 deletions.
15 changes: 15 additions & 0 deletions pyproject.toml
@@ -39,3 +39,18 @@ wrap-descriptions = 79
ignore_missing_imports = true
allow_redefinition = true
follow_imports = "silent"

[tool.isort]
known_first_party = "gluonts"
known_third_party = [
"mxnet",
"numpy",
"pandas",
"pytest",
"scipy",
"tqdm",
"torch",
"lightning",
]
line_length = 79
profile = "black"
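
For reference, under this configuration isort groups imports into standard-library, known third-party, and first-party gluonts sections, black-compatible and wrapped at 79 characters. A hypothetical module (names chosen for illustration only) would be ordered like this:

# Hypothetical example of the import order this [tool.isort] section enforces.
import itertools  # standard library first

import numpy as np  # then known third-party packages
import torch

from gluonts.core.component import validated  # first-party gluonts last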
24 changes: 14 additions & 10 deletions src/gluonts/model/forecast_generator.py
@@ -45,11 +45,10 @@ def log_once(msg):
LOG_CACHE.add(msg)


# different deep learning frameworks generate predictions and the tensor to
# numpy conversion differently, use a dispatching function to prevent needing
# a ForecastGenerators for each framework
# Convert tensors from different deep learning frameworks to numpy. We use a
# dispatching function to avoid needing a separate ForecastGenerator for each
# framework.
@singledispatch
def predict_to_numpy(prediction_net, kwargs) -> np.ndarray:
def to_numpy(x) -> np.ndarray:
raise NotImplementedError
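
The MXNet registrations for this dispatcher appear further down, in src/gluonts/mx/model/predictor.py; the PyTorch side (not shown in this excerpt) would plug in the same way. A minimal sketch, assuming a torch tensor input:

import numpy as np
import torch

from gluonts.model.forecast_generator import to_numpy


@to_numpy.register(torch.Tensor)
def _(x: torch.Tensor) -> np.ndarray:
    # Detach from the autograd graph and move to CPU before converting.
    return x.detach().cpu().numpy()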


@@ -116,17 +115,22 @@ def __call__(
) -> Iterator[Forecast]:
for batch in inference_data_loader:
inputs = select(input_names, batch, ignore_missing=True)
outputs = predict_to_numpy(prediction_net, inputs)
if output_transform is not None:
outputs = output_transform(batch, outputs)
(outputs,), loc, scale = prediction_net(*inputs.values())
outputs = to_numpy(outputs)
if scale is not None:
outputs = outputs * to_numpy(scale[..., None])
if loc is not None:
outputs = outputs + to_numpy(loc[..., None])

if output_transform is not None:
log_once(OUTPUT_TRANSFORM_NOT_SUPPORTED_MSG)
if num_samples:
log_once(NOT_SAMPLE_BASED_MSG)

i = -1
for i, output in enumerate(outputs):
yield QuantileForecast(
output,
output.T,
start_date=batch[FieldName.FORECAST_START][i],
item_id=batch[FieldName.ITEM_ID][i]
if FieldName.ITEM_ID in batch
@@ -153,14 +157,14 @@ def __call__(
) -> Iterator[Forecast]:
for batch in inference_data_loader:
inputs = select(input_names, batch, ignore_missing=True)
outputs = predict_to_numpy(prediction_net, inputs)
outputs = to_numpy(prediction_net(*inputs.values()))
if output_transform is not None:
outputs = output_transform(batch, outputs)
if num_samples:
num_collected_samples = outputs[0].shape[0]
collected_samples = [outputs]
while num_collected_samples < num_samples:
outputs = predict_to_numpy(prediction_net, inputs)
outputs = to_numpy(prediction_net(*inputs.values()))
if output_transform is not None:
outputs = output_transform(batch, outputs)
collected_samples.append(outputs)
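The key change above is the output contract: a quantile-producing network now returns a triple ((outputs,), loc, scale), where outputs has shape (batch, prediction_length, num_quantiles), and the generator applies the affine transform and transposes each item before building a QuantileForecast. A standalone numpy sketch of that post-processing (all shapes are illustrative assumptions, not from the diff):

import numpy as np

batch, pred_len, num_q = 4, 24, 9
outputs = np.random.rand(batch, pred_len, num_q)  # per-quantile predictions
loc = np.random.rand(batch, 1)                    # optional per-item shift
scale = np.random.rand(batch, 1)                  # optional per-item scale

outputs = outputs * scale[..., None]  # broadcast over the quantile axis
outputs = outputs + loc[..., None]

# QuantileForecast expects (num_quantiles, prediction_length) per item,
# hence the `output.T` in the generator loop above.
item_0 = outputs[0].T
assert item_0.shape == (num_q, pred_len)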
13 changes: 9 additions & 4 deletions src/gluonts/mx/model/predictor.py
@@ -26,7 +26,7 @@
from gluonts.model.forecast_generator import (
ForecastGenerator,
SampleForecastGenerator,
predict_to_numpy,
to_numpy,
)
from gluonts.model.predictor import OutputTransform, Predictor
from gluonts.mx.batchify import batchify
@@ -43,9 +43,14 @@
from gluonts.transform import Transformation


@predict_to_numpy.register(mx.gluon.Block)
def _(prediction_net: mx.gluon.Block, kwargs) -> np.ndarray:
return prediction_net(*kwargs.values()).asnumpy()
@to_numpy.register(mx.nd.NDArray)
def _(x: mx.nd.NDArray) -> np.ndarray:
return x.asnumpy()


@to_numpy.register(mx.sym.Symbol)
def _(x: mx.sym.Symbol) -> np.ndarray:
return x.asnumpy()


class GluonPredictor(Predictor):
6 changes: 3 additions & 3 deletions src/gluonts/mx/model/seq2seq/_forking_network.py
@@ -288,7 +288,7 @@ def hybrid_forward(
future_feat_dynamic: Tensor,
feat_static_cat: Tensor,
past_observed_values: Tensor,
) -> Tensor:
) -> Tuple[Tuple[Tensor, ...], Tensor, Tensor]:
"""
Parameters
----------
@@ -332,8 +332,8 @@ def hybrid_forward(
axis=1,
)

# shape: (num_test_ts, num_quantiles, prediction_length)
return fcst_output.swapaxes(2, 1)
# shape: (num_test_ts, prediction_length, num_quantiles)
return (fcst_output,), None, None


class ForkingSeq2SeqDistributionPredictionNetwork(ForkingSeq2SeqNetworkBase):
8 changes: 5 additions & 3 deletions src/gluonts/mx/model/seq2seq/_seq2seq_network.py
@@ -11,6 +11,8 @@
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.

from typing import Tuple

import mxnet as mx

from gluonts.core.component import validated
@@ -157,7 +159,7 @@ def hybrid_forward(
feat_static_cat: Tensor,
past_feat_dynamic_real: Tensor,
future_feat_dynamic_real: Tensor,
) -> Tensor:
) -> Tuple[Tuple[Tensor, ...], Tensor, Tensor]:
"""
Parameters
@@ -185,6 +187,6 @@ def hybrid_forward(
past_feat_dynamic_real=past_feat_dynamic_real,
future_feat_dynamic_real=future_feat_dynamic_real,
)
predictions = self.quantile_proj(scaled_decoder_output).swapaxes(2, 1)
predictions = self.quantile_proj(scaled_decoder_output)

return predictions
return (predictions,), None, None
7 changes: 3 additions & 4 deletions src/gluonts/mx/model/tft/_network.py
@@ -11,7 +11,7 @@
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.

from typing import List, Type
from typing import List, Tuple, Type

import numpy as np
from mxnet.gluon import HybridBlock, nn
@@ -487,7 +487,7 @@ def hybrid_forward(
feat_dynamic_cat: Tensor,
feat_static_real: Tensor,
feat_static_cat: Tensor,
):
) -> Tuple[Tuple[Tensor, ...], Tensor, Tensor]:
(
past_covariates,
future_covariates,
@@ -515,5 +515,4 @@
)

preds = self._postprocess(F, preds, offset, scale)
preds = F.swapaxes(preds, dim1=1, dim2=2)
return preds
return (preds,), None, None
4 changes: 4 additions & 0 deletions src/gluonts/torch/distributions/__init__.py
@@ -29,7 +29,9 @@
)
from .isqf import ISQF, ISQFOutput
from .negative_binomial import NegativeBinomialOutput
from .output import Output
from .piecewise_linear import PiecewiseLinear, PiecewiseLinearOutput
from .quantile_output import QuantileOutput
from .spliced_binned_pareto import (
SplicedBinnedPareto,
SplicedBinnedParetoOutput,
@@ -54,9 +56,11 @@
"LaplaceOutput",
"NegativeBinomialOutput",
"NormalOutput",
"Output",
"PiecewiseLinear",
"PiecewiseLinearOutput",
"PoissonOutput",
"QuantileOutput",
"SplicedBinnedPareto",
"SplicedBinnedParetoOutput",
"StudentTOutput",
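With Output and QuantileOutput exported here, both output families share one base class. A short sketch of what the unification makes possible (hedged: the exact QuantileOutput constructor arguments are an assumption):

from gluonts.torch.distributions import (
    NormalOutput,
    Output,
    QuantileOutput,
)

# Both likelihood-based and quantile-based outputs are now `Output`s,
# so estimator code can accept either one through a single argument.
distr_output: Output = NormalOutput()
quantile_output: Output = QuantileOutput(quantiles=[0.1, 0.5, 0.9])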
109 changes: 27 additions & 82 deletions src/gluonts/torch/distributions/distribution_output.py
@@ -11,92 +11,28 @@
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.

from typing import Callable, Dict, Optional, Tuple, Type
from typing import Dict, Optional, Tuple

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import (
Beta,
Distribution,
Gamma,
Laplace,
Normal,
Poisson,
Laplace,
)

from gluonts.core.component import validated
from gluonts.model.forecast_generator import (
DistributionForecastGenerator,
ForecastGenerator,
)
from gluonts.torch.distributions import AffineTransformed
from gluonts.torch.modules.lambda_layer import LambdaLayer


class PtArgProj(nn.Module):
r"""
A PyTorch module that can be used to project from a dense layer
to PyTorch distribution arguments.
Parameters
----------
in_features
Size of the incoming features.
dim_args
Dictionary with string key and int value
dimension of each arguments that will be passed to the domain
map, the names are not used.
domain_map
Function returning a tuple containing one tensor
a function or a nn.Module. This will be called with num_args
arguments and should return a tuple of outputs that will be
used when calling the distribution constructor.
"""

def __init__(
self,
in_features: int,
args_dim: Dict[str, int],
domain_map: Callable[..., Tuple[torch.Tensor]],
**kwargs,
) -> None:
super().__init__(**kwargs)
self.args_dim = args_dim
self.proj = nn.ModuleList(
[nn.Linear(in_features, dim) for dim in args_dim.values()]
)
self.domain_map = domain_map

def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor]:
params_unbounded = [proj(x) for proj in self.proj]

return self.domain_map(*params_unbounded)


class Output:
"""
Class to connect a network to some output.
"""

in_features: int
args_dim: Dict[str, int]
_dtype: Type = np.float32

@property
def dtype(self):
return self._dtype

@dtype.setter
def dtype(self, dtype: Type):
self._dtype = dtype

def get_args_proj(self, in_features: int) -> nn.Module:
return PtArgProj(
in_features=in_features,
args_dim=self.args_dim,
domain_map=LambdaLayer(self.domain_map),
)

def domain_map(self, *args: torch.Tensor):
raise NotImplementedError()
from .output import Output


class DistributionOutput(Output):
@@ -107,8 +43,8 @@ class DistributionOutput(Output):
distr_cls: type

@validated()
def __init__(self) -> None:
pass
def __init__(self, beta: float = 0.0) -> None:
self.beta = beta

def _base_distribution(self, distr_args):
return self.distr_cls(*distr_args)
@@ -140,6 +76,20 @@ def distribution(
else:
return AffineTransformed(distr, loc=loc, scale=scale)

def loss(
self,
target: torch.Tensor,
distr_args: Tuple[torch.Tensor, ...],
loc: Optional[torch.Tensor] = None,
scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
distribution = self.distribution(distr_args, loc=loc, scale=scale)
nll = -distribution.log_prob(target)
if self.beta > 0.0:
variance = distribution.variance
nll = nll * (variance.detach() ** self.beta)
return nll

@property
def event_shape(self) -> Tuple:
r"""
@@ -156,15 +106,6 @@ def event_dim(self) -> int:
"""
return len(self.event_shape)

@property
def value_in_support(self) -> float:
r"""
A float that will have a valid numeric value when computing the
log-loss of the corresponding distribution. By default 0.0.
This value will be used when padding data series.
"""
return 0.0

def domain_map(self, *args: torch.Tensor):
r"""
Converts arguments to the right shape and domain. The domain depends
@@ -174,6 +115,10 @@ def domain_map(self, *args: torch.Tensor):
"""
raise NotImplementedError()

@property
def forecast_generator(self) -> ForecastGenerator:
return DistributionForecastGenerator(self)


class NormalOutput(DistributionOutput):
args_dim: Dict[str, int] = {"loc": 1, "scale": 1}
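The new beta argument turns the plain NLL into a variance-weighted NLL: multiplying by the detached variance raised to beta (as in the beta-NLL of Seitzer et al., 2022) re-balances the loss across points with different predicted variance, with beta=0.0 recovering the ordinary NLL. A minimal usage sketch, assuming NormalOutput inherits the new constructor unchanged and that domain_map squeezes the parameter dimensions as usual:

import torch

from gluonts.torch.distributions import NormalOutput

output = NormalOutput(beta=0.5)              # beta=0.0 recovers the plain NLL
proj = output.get_args_proj(in_features=16)  # maps features to (loc, scale)

features = torch.randn(8, 16)                # hypothetical decoder output
distr_args = proj(features)
target = torch.randn(8)

loss = output.loss(target, distr_args)       # shape (8,), beta-weighted NLL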
23 changes: 18 additions & 5 deletions src/gluonts/torch/distributions/implicit_quantile_network.py
@@ -201,14 +201,27 @@ def get_args_proj(self, in_features: int) -> nn.Module:
def domain_map(cls, *args):
return args

def distribution(self, distr_args, loc=0, scale=None) -> Distribution:
def distribution(
self, distr_args, loc=0, scale=None
) -> ImplicitQuantileNetwork:
(outputs, taus) = distr_args

if scale is None:
return self.distr_cls(outputs=outputs, taus=taus)
else:
return self.distr_cls(outputs=loc + outputs * scale, taus=taus)
if scale is not None:
outputs = outputs * scale
if loc is not None:
outputs = outputs + loc
return self.distr_cls(outputs=outputs, taus=taus)

@property
def event_shape(self):
return ()

def loss(
self,
target: torch.Tensor,
distr_args: Tuple[torch.Tensor, ...],
loc: Optional[torch.Tensor] = None,
scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
distribution = self.distribution(distr_args, loc=loc, scale=scale)
return distribution.quantile_loss(target)
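
Unlike DistributionOutput.loss, the IQN output overrides loss to delegate to the network's quantile loss: the model is trained with a pinball loss at randomly sampled quantile levels tau rather than with a log-likelihood. A generic sketch of the pinball loss being delegated to (the actual computation lives in ImplicitQuantileNetwork.quantile_loss and may differ in details):

import torch


def pinball_loss(
    target: torch.Tensor, pred: torch.Tensor, tau: torch.Tensor
) -> torch.Tensor:
    # tau * error when the prediction undershoots, (tau - 1) * error otherwise.
    error = target - pred
    return torch.maximum(tau * error, (tau - 1) * error)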
2 changes: 1 addition & 1 deletion src/gluonts/torch/distributions/isqf.py
@@ -722,7 +722,7 @@ def __init__(
) -> None:
# ISQF reduces to IQF when num_pieces = 1

super().__init__(self)
super().__init__()

assert (
isinstance(num_pieces, int) and num_pieces > 0