[AIRFLOW-491] Add feature to pass extra api configs to BQ Hook (apach…
xnuinside authored and Alice Berard committed Jan 3, 2019
1 parent 5595c1e commit aae770c
Showing 3 changed files with 208 additions and 99 deletions.
233 changes: 148 additions & 85 deletions airflow/contrib/hooks/bigquery_hook.py
@@ -24,6 +24,7 @@

import time
from builtins import range
from copy import deepcopy

from past.builtins import basestring

@@ -195,10 +196,19 @@ class BigQueryBaseCursor(LoggingMixin):
PEP 249 cursor isn't needed.
"""

-    def __init__(self, service, project_id, use_legacy_sql=True):
+    def __init__(self,
+                 service,
+                 project_id,
+                 use_legacy_sql=True,
+                 api_resource_configs=None):

self.service = service
self.project_id = project_id
self.use_legacy_sql = use_legacy_sql
+        if api_resource_configs:
+            _validate_value("api_resource_configs", api_resource_configs, dict)
+        self.api_resource_configs = api_resource_configs \
+            if api_resource_configs else {}
self.running_job_id = None

def create_empty_table(self,
@@ -238,8 +248,7 @@ def create_empty_table(self,
:return:
"""
if time_partitioning is None:
time_partitioning = dict()

project_id = project_id if project_id is not None else self.project_id

table_resource = {
@@ -473,11 +482,11 @@ def create_external_table(self,
def run_query(self,
bql=None,
sql=None,
-                  destination_dataset_table=False,
+                  destination_dataset_table=None,
write_disposition='WRITE_EMPTY',
allow_large_results=False,
flatten_results=None,
-                  udf_config=False,
+                  udf_config=None,
use_legacy_sql=None,
maximum_billing_tier=None,
maximum_bytes_billed=None,
@@ -486,7 +495,8 @@ def run_query(self,
labels=None,
schema_update_options=(),
priority='INTERACTIVE',
-                  time_partitioning=None):
+                  time_partitioning=None,
+                  api_resource_configs=None):
"""
Executes a BigQuery SQL query. Optionally persists results in a BigQuery
table. See here:
@@ -518,6 +528,13 @@ def run_query(self,
:param use_legacy_sql: Whether to use legacy SQL (true) or standard SQL (false).
If `None`, defaults to `self.use_legacy_sql`.
:type use_legacy_sql: boolean
:param api_resource_configs: a dictionary containing the 'configuration'
    parameters applied to the Google BigQuery Jobs API:
    https://cloud.google.com/bigquery/docs/reference/rest/v2/jobs
    for example, {'query': {'useQueryCache': False}}. You can use it to
    supply parameters that are not exposed as arguments of the
    BigQueryHook.
:type api_resource_configs: dict
:type udf_config: list
:param maximum_billing_tier: Positive integer that serves as a
multiplier of the basic price.
@@ -550,12 +567,22 @@ def run_query(self,
:type time_partitioning: dict
"""
if not api_resource_configs:
api_resource_configs = self.api_resource_configs
else:
_validate_value('api_resource_configs',
api_resource_configs, dict)
configuration = deepcopy(api_resource_configs)
if 'query' not in configuration:
configuration['query'] = {}

else:
_validate_value("api_resource_configs['query']",
configuration['query'], dict)

# TODO remove `bql` in Airflow 2.0 - Jira: [AIRFLOW-2513]
if time_partitioning is None:
time_partitioning = {}
sql = bql if sql is None else sql

# TODO remove `bql` in Airflow 2.0 - Jira: [AIRFLOW-2513]
if bql:
import warnings
warnings.warn('Deprecated parameter `bql` used in '
@@ -566,95 +593,109 @@ def run_query(self,
'Airflow.',
category=DeprecationWarning)

-        if sql is None:
-            raise TypeError('`BigQueryBaseCursor.run_query` missing 1 required '
-                            'positional argument: `sql`')
+        if sql is None and not configuration['query'].get('query', None):
+            raise TypeError('`BigQueryBaseCursor.run_query` '
+                            'missing 1 required positional argument: `sql`')
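
# Illustration (not part of this commit): with the check above, the query
# text itself may come from api_resource_configs instead of the `sql` arg,
# e.g. cursor.run_query(api_resource_configs={'query': {'query': 'SELECT 1'}})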

# BigQuery also allows you to define how you want a table's schema to change
# as a side effect of a query job
# for more details:
# https://cloud.google.com/bigquery/docs/reference/rest/v2/jobs#configuration.query.schemaUpdateOptions

allowed_schema_update_options = [
'ALLOW_FIELD_ADDITION', "ALLOW_FIELD_RELAXATION"
]
if not set(allowed_schema_update_options).issuperset(
set(schema_update_options)):
raise ValueError(
"{0} contains invalid schema update options. "
"Please only use one or more of the following options: {1}"
.format(schema_update_options, allowed_schema_update_options))

if use_legacy_sql is None:
use_legacy_sql = self.use_legacy_sql
if not set(allowed_schema_update_options
).issuperset(set(schema_update_options)):
raise ValueError("{0} contains invalid schema update options. "
"Please only use one or more of the following "
"options: {1}"
.format(schema_update_options,
allowed_schema_update_options))

configuration = {
'query': {
'query': sql,
'useLegacySql': use_legacy_sql,
'maximumBillingTier': maximum_billing_tier,
'maximumBytesBilled': maximum_bytes_billed,
'priority': priority
}
}
if schema_update_options:
if write_disposition not in ["WRITE_APPEND", "WRITE_TRUNCATE"]:
raise ValueError("schema_update_options is only "
"allowed if write_disposition is "
"'WRITE_APPEND' or 'WRITE_TRUNCATE'.")

if destination_dataset_table:
if '.' not in destination_dataset_table:
raise ValueError(
'Expected destination_dataset_table name in the format of '
'<dataset>.<table>. Got: {}'.format(
destination_dataset_table))
destination_project, destination_dataset, destination_table = \
_split_tablename(table_input=destination_dataset_table,
default_project_id=self.project_id)
configuration['query'].update({
'allowLargeResults': allow_large_results,
'flattenResults': flatten_results,
'writeDisposition': write_disposition,
'createDisposition': create_disposition,
'destinationTable': {
'projectId': destination_project,
'datasetId': destination_dataset,
'tableId': destination_table,
}
})
if udf_config:
if not isinstance(udf_config, list):
raise TypeError("udf_config argument must have a type 'list'"
" not {}".format(type(udf_config)))
configuration['query'].update({
'userDefinedFunctionResources': udf_config
})

if query_params:
if self.use_legacy_sql:
raise ValueError("Query parameters are not allowed when using "
"legacy SQL")
else:
configuration['query']['queryParameters'] = query_params
destination_dataset_table = {
'projectId': destination_project,
'datasetId': destination_dataset,
'tableId': destination_table,
}

if labels:
configuration['labels'] = labels
query_param_list = [
(sql, 'query', None, str),
(priority, 'priority', 'INTERACTIVE', str),
(use_legacy_sql, 'useLegacySql', self.use_legacy_sql, bool),
(query_params, 'queryParameters', None, dict),
(udf_config, 'userDefinedFunctionResources', None, list),
(maximum_billing_tier, 'maximumBillingTier', None, int),
(maximum_bytes_billed, 'maximumBytesBilled', None, float),
(time_partitioning, 'timePartitioning', {}, dict),
(schema_update_options, 'schemaUpdateOptions', None, tuple),
(destination_dataset_table, 'destinationTable', None, dict)
]
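
# (illustrative note, not in the original diff) each tuple above is
# (value passed to run_query, Jobs API field name, default, expected type);
# the loop below resolves each value, checks it against api_resource_configs
# for duplicates, and validates its type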

time_partitioning = _cleanse_time_partitioning(
destination_dataset_table,
time_partitioning
)
if time_partitioning:
configuration['query'].update({
'timePartitioning': time_partitioning
})
for param_tuple in query_param_list:

if schema_update_options:
if write_disposition not in ["WRITE_APPEND", "WRITE_TRUNCATE"]:
raise ValueError("schema_update_options is only "
"allowed if write_disposition is "
"'WRITE_APPEND' or 'WRITE_TRUNCATE'.")
else:
self.log.info(
"Adding experimental "
"'schemaUpdateOptions': {0}".format(schema_update_options))
configuration['query'][
'schemaUpdateOptions'] = schema_update_options
param, param_name, param_default, param_type = param_tuple

if param_name not in configuration['query'] and param in [None, {}, ()]:
if param_name == 'timePartitioning':
param_default = _cleanse_time_partitioning(
destination_dataset_table, time_partitioning)
param = param_default

if param not in [None, {}, ()]:
_api_resource_configs_duplication_check(
param_name, param, configuration['query'])

configuration['query'][param_name] = param

                # validate the type of the provided param only now: it can
                # come from two sources (a run_query() argument or
                # api_resource_configs), so the final value must be resolved first

_validate_value(param_name, configuration['query'][param_name],
param_type)

if param_name == 'schemaUpdateOptions' and param:
self.log.info("Adding experimental 'schemaUpdateOptions': "
"{0}".format(schema_update_options))

if param_name == 'destinationTable':
for key in ['projectId', 'datasetId', 'tableId']:
if key not in configuration['query']['destinationTable']:
                        raise ValueError(
                            "Invalid 'destinationTable' in "
                            "api_resource_configs. 'destinationTable' "
                            "must be a dict with 'projectId', "
                            "'datasetId' and 'tableId' keys")

configuration['query'].update({
'allowLargeResults': allow_large_results,
'flattenResults': flatten_results,
'writeDisposition': write_disposition,
'createDisposition': create_disposition,
})

if 'useLegacySql' in configuration['query'] and \
'queryParameters' in configuration['query']:
raise ValueError("Query parameters are not allowed "
"when using legacy SQL")

if labels:
_api_resource_configs_duplication_check(
'labels', labels, configuration)
configuration['labels'] = labels

return self.run_with_configuration(configuration)
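
For context, a minimal usage sketch of the new argument (not part of the
commit; the connection id and query are placeholders):

    from airflow.contrib.hooks.bigquery_hook import BigQueryHook

    hook = BigQueryHook(bigquery_conn_id='my_gcp_conn', use_legacy_sql=False)
    cursor = hook.get_conn().cursor()

    # extra Jobs API fields under 'configuration.query' are deep-copied into
    # the job configuration and duplicate-checked against the keyword args
    cursor.run_query(
        sql='SELECT 1',
        api_resource_configs={'query': {'useQueryCache': False}})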

@@ -888,8 +929,7 @@ def run_load(self,
# https://cloud.google.com/bigquery/docs/reference/rest/v2/jobs#configuration.query.tableDefinitions.(key).sourceFormat
if src_fmt_configs is None:
src_fmt_configs = {}
if time_partitioning is None:
time_partitioning = {}

source_format = source_format.upper()
allowed_formats = [
"CSV", "NEWLINE_DELIMITED_JSON", "AVRO", "GOOGLE_SHEETS",
@@ -1167,10 +1207,6 @@ def run_table_delete(self, deletion_dataset_table,
:type ignore_if_missing: boolean
:return:
"""
if '.' not in deletion_dataset_table:
raise ValueError(
'Expected deletion_dataset_table name in the format of '
'<dataset>.<table>. Got: {}'.format(deletion_dataset_table))
deletion_project, deletion_dataset, deletion_table = \
_split_tablename(table_input=deletion_dataset_table,
default_project_id=self.project_id)
@@ -1536,6 +1572,12 @@ def _bq_cast(string_field, bq_type):


def _split_tablename(table_input, default_project_id, var_name=None):

    if '.' not in table_input:
        raise ValueError(
            'Expected table name in the format of '
            '<dataset>.<table>. Got: {}'.format(table_input))

if not default_project_id:
raise ValueError("INTERNAL: No default project is specified")

@@ -1597,6 +1639,10 @@ def var_print(var_name):

def _cleanse_time_partitioning(destination_dataset_table, time_partitioning_in):
# if it is a partitioned table ($ is in the table name) add partition load option

if time_partitioning_in is None:
time_partitioning_in = {}

time_partitioning_out = {}
if destination_dataset_table and '$' in destination_dataset_table:
if time_partitioning_in.get('field'):
@@ -1607,3 +1653,20 @@ def _cleanse_time_partitioning(destination_dataset_table, time_partitioning_in):

time_partitioning_out.update(time_partitioning_in)
return time_partitioning_out
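
# Editor's illustration (not part of the commit) of the None-handling added
# above; the second case assumes the collapsed lines tag decorator-style
# table names with {'type': 'DAY'}, as in the released hook:
#   _cleanse_time_partitioning(None, None)              # -> {}
#   _cleanse_time_partitioning('ds.tbl$20190101',
#                              {'expirationMs': 1000})
#   # -> {'type': 'DAY', 'expirationMs': 1000}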


def _validate_value(key, value, expected_type):
    """Raise a TypeError if `value` does not have `expected_type`."""
    if not isinstance(value, expected_type):
        raise TypeError("{} argument must have a type {}, not {}".format(
            key, expected_type, type(value)))


def _api_resource_configs_duplication_check(key, value, config_dict):
    if key in config_dict and value != config_dict[key]:
        raise ValueError("Duplicate values given for the {param_name} param: "
                         "`api_resource_configs` contains {param_name} in its "
                         "`query` config, and {param_name} was also passed as "
                         "an argument to run_query(). Please remove one of the "
                         "duplicates.".format(param_name=key))
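
Together these helpers define the merge semantics: types are enforced, and a
parameter given both as a run_query() argument and inside api_resource_configs
must not conflict. A short sketch (editor's example, not part of the commit):

    _validate_value('api_resource_configs', {'query': {}}, dict)  # ok
    # _validate_value('api_resource_configs', 'oops', dict) would raise TypeError

    query_config = {'useLegacySql': False}
    _api_resource_configs_duplication_check(
        'useLegacySql', False, query_config)  # identical values: no error
    try:
        _api_resource_configs_duplication_check(
            'useLegacySql', True, query_config)  # conflicting values
    except ValueError as err:
        print(err)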
18 changes: 15 additions & 3 deletions airflow/contrib/operators/bigquery_operator.py
@@ -75,6 +75,13 @@ class BigQueryOperator(BaseOperator):
(without incurring a charge). If unspecified, this will be
set to your project default.
:type maximum_bytes_billed: float
:param api_resource_configs: a dictionary containing the 'configuration'
    parameters applied to the Google BigQuery Jobs API:
    https://cloud.google.com/bigquery/docs/reference/rest/v2/jobs
    for example, {'query': {'useQueryCache': False}}. You can use it to
    supply parameters that are not exposed as arguments of the
    BigQueryOperator.
:type api_resource_configs: dict
:param schema_update_options: Allows the schema of the destination
table to be updated as a side effect of the load job.
:type schema_update_options: tuple
@@ -118,7 +125,8 @@ def __init__(self,
query_params=None,
labels=None,
priority='INTERACTIVE',
-                 time_partitioning={},
+                 time_partitioning=None,
+                 api_resource_configs=None,
*args,
**kwargs):
super(BigQueryOperator, self).__init__(*args, **kwargs)
Expand All @@ -140,7 +148,10 @@ def __init__(self,
self.labels = labels
self.bq_cursor = None
self.priority = priority
-        self.time_partitioning = time_partitioning
+        self.time_partitioning = time_partitioning or {}
+        self.api_resource_configs = api_resource_configs or {}

# TODO remove `bql` in Airflow 2.0
if self.bql:
@@ -179,7 +190,8 @@ def execute(self, context):
labels=self.labels,
schema_update_options=self.schema_update_options,
priority=self.priority,
-            time_partitioning=self.time_partitioning
+            time_partitioning=self.time_partitioning,
+            api_resource_configs=self.api_resource_configs,
)

def on_kill(self):
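
A hypothetical DAG snippet (not part of the commit) wiring the new operator
argument; the DAG id, project, dataset and table names are placeholders:

    from datetime import datetime

    from airflow import DAG
    from airflow.contrib.operators.bigquery_operator import BigQueryOperator

    dag = DAG('bq_api_configs_example',
              start_date=datetime(2019, 1, 3),
              schedule_interval=None)

    no_cache_query = BigQueryOperator(
        task_id='query_without_cache',
        sql='SELECT COUNT(*) FROM `my_project.my_dataset.my_table`',
        use_legacy_sql=False,
        api_resource_configs={'query': {'useQueryCache': False}},
        dag=dag)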