airflow settings 源码

  • 2022-10-20
  • 浏览 (583)

airflow settings 代码

文件路径:/airflow/settings.py

#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

import atexit
import functools
import json
import logging
import os
import sys
import warnings
from typing import TYPE_CHECKING, Callable

import pendulum
import sqlalchemy
from sqlalchemy import create_engine, exc
from sqlalchemy.engine import Engine
from sqlalchemy.orm import scoped_session, sessionmaker
from sqlalchemy.orm.session import Session as SASession
from sqlalchemy.pool import NullPool

from airflow.configuration import AIRFLOW_HOME, WEBSERVER_CONFIG, conf  # NOQA F401
from airflow.exceptions import RemovedInAirflow3Warning
from airflow.executors import executor_constants
from airflow.logging_config import configure_logging
from airflow.utils.orm_event_handlers import setup_event_handlers

if TYPE_CHECKING:
    from airflow.www.utils import UIAlert

log = logging.getLogger(__name__)


TIMEZONE = pendulum.tz.timezone('UTC')
try:
    tz = conf.get_mandatory_value("core", "default_timezone")
    if tz == "system":
        TIMEZONE = pendulum.tz.local_timezone()
    else:
        TIMEZONE = pendulum.tz.timezone(tz)
except Exception:
    pass
log.info("Configured default timezone %s", TIMEZONE)


HEADER = '\n'.join(
    [
        r'  ____________       _____________',
        r' ____    |__( )_________  __/__  /________      __',
        r'____  /| |_  /__  ___/_  /_ __  /_  __ \_ | /| / /',
        r'___  ___ |  / _  /   _  __/ _  / / /_/ /_ |/ |/ /',
        r' _/_/  |_/_/  /_/    /_/    /_/  \____/____/|__/',
    ]
)

LOGGING_LEVEL = logging.INFO

# the prefix to append to gunicorn worker processes after init
GUNICORN_WORKER_READY_PREFIX = "[ready] "

LOG_FORMAT = conf.get('logging', 'log_format')
SIMPLE_LOG_FORMAT = conf.get('logging', 'simple_log_format')

SQL_ALCHEMY_CONN: str | None = None
PLUGINS_FOLDER: str | None = None
LOGGING_CLASS_PATH: str | None = None
DONOT_MODIFY_HANDLERS: bool | None = None
DAGS_FOLDER: str = os.path.expanduser(conf.get_mandatory_value('core', 'DAGS_FOLDER'))

engine: Engine
Session: Callable[..., SASession]

# The JSON library to use for DAG Serialization and De-Serialization
json = json

# Dictionary containing State and colors associated to each state to
# display on the Webserver
STATE_COLORS = {
    "deferred": "mediumpurple",
    "failed": "red",
    "queued": "gray",
    "removed": "lightgrey",
    "restarting": "violet",
    "running": "lime",
    "scheduled": "tan",
    "shutdown": "blue",
    "skipped": "hotpink",
    "success": "green",
    "up_for_reschedule": "turquoise",
    "up_for_retry": "gold",
    "upstream_failed": "orange",
}


@functools.lru_cache(maxsize=None)
def _get_rich_console(file):
    # Delay imports until we need it
    import rich.console

    return rich.console.Console(file=file)


def custom_show_warning(message, category, filename, lineno, file=None, line=None):
    """Custom function to print rich and visible warnings"""
    # Delay imports until we need it
    from rich.markup import escape

    msg = f"[bold]{line}" if line else f"[bold][yellow]{filename}:{lineno}"
    msg += f" {category.__name__}[/bold]: {escape(str(message))}[/yellow]"
    write_console = _get_rich_console(file or sys.stderr)
    write_console.print(msg, soft_wrap=True)


def replace_showwarning(replacement):
    """Replace ``warnings.showwarning``, returning the original.

    This is useful since we want to "reset" the ``showwarning`` hook on exit to
    avoid lazy-loading issues. If a warning is emitted after Python cleaned up
    the import system, we would no longer be able to import ``rich``.
    """
    original = warnings.showwarning
    warnings.showwarning = replacement
    return original


original_show_warning = replace_showwarning(custom_show_warning)
atexit.register(functools.partial(replace_showwarning, original_show_warning))


def task_policy(task) -> None:
    """
    This policy setting allows altering tasks after they are loaded in
    the DagBag. It allows administrator to rewire some task's parameters.
    Alternatively you can raise ``AirflowClusterPolicyViolation`` exception
    to stop DAG from being executed.

    To define policy, add a ``airflow_local_settings`` module
    to your PYTHONPATH that defines this ``task_policy`` function.

    Here are a few examples of how this can be useful:

    * You could enforce a specific queue (say the ``spark`` queue)
        for tasks using the ``SparkOperator`` to make sure that these
        tasks get wired to the right workers
    * You could enforce a task timeout policy, making sure that no tasks run
        for more than 48 hours

    :param task: task to be mutated
    """


def dag_policy(dag) -> None:
    """
    This policy setting allows altering DAGs after they are loaded in
    the DagBag. It allows administrator to rewire some DAG's parameters.
    Alternatively you can raise ``AirflowClusterPolicyViolation`` exception
    to stop DAG from being executed.

    To define policy, add a ``airflow_local_settings`` module
    to your PYTHONPATH that defines this ``dag_policy`` function.

    Here are a few examples of how this can be useful:

    * You could enforce default user for DAGs
    * Check if every DAG has configured tags

    :param dag: dag to be mutated
    """


def task_instance_mutation_hook(task_instance):
    """
    This setting allows altering task instances before they are queued by
    the Airflow scheduler.

    To define task_instance_mutation_hook, add a ``airflow_local_settings`` module
    to your PYTHONPATH that defines this ``task_instance_mutation_hook`` function.

    This could be used, for instance, to modify the task instance during retries.

    :param task_instance: task instance to be mutated
    """


task_instance_mutation_hook.is_noop = True  # type: ignore


def pod_mutation_hook(pod):
    """
    This setting allows altering ``kubernetes.client.models.V1Pod`` object
    before they are passed to the Kubernetes client for scheduling.

    To define a pod mutation hook, add a ``airflow_local_settings`` module
    to your PYTHONPATH that defines this ``pod_mutation_hook`` function.
    It receives a ``Pod`` object and can alter it where needed.

    This could be used, for instance, to add sidecar or init containers
    to every worker pod launched by KubernetesExecutor or KubernetesPodOperator.
    """


def get_airflow_context_vars(context):
    """
    This setting allows getting the airflow context vars, which are key value pairs.
    They are then injected to default airflow context vars, which in the end are
    available as environment variables when running tasks
    dag_id, task_id, execution_date, dag_run_id, try_number are reserved keys.
    To define it, add a ``airflow_local_settings`` module
    to your PYTHONPATH that defines this ``get_airflow_context_vars`` function.

    :param context: The context for the task_instance of interest.
    """
    return {}


def get_dagbag_import_timeout(dag_file_path: str) -> int | float:
    """
    This setting allows for dynamic control of the DAG file parsing timeout based on the DAG file path.

    It is useful when there are a few DAG files requiring longer parsing times, while others do not.
    You can control them separately instead of having one value for all DAG files.

    If the return value is less than or equal to 0, it means no timeout during the DAG parsing.
    """
    return conf.getfloat('core', 'DAGBAG_IMPORT_TIMEOUT')


def configure_vars():
    """Configure Global Variables from airflow.cfg"""
    global SQL_ALCHEMY_CONN
    global DAGS_FOLDER
    global PLUGINS_FOLDER
    global DONOT_MODIFY_HANDLERS
    SQL_ALCHEMY_CONN = conf.get('database', 'SQL_ALCHEMY_CONN')
    DAGS_FOLDER = os.path.expanduser(conf.get('core', 'DAGS_FOLDER'))

    PLUGINS_FOLDER = conf.get('core', 'plugins_folder', fallback=os.path.join(AIRFLOW_HOME, 'plugins'))

    # If donot_modify_handlers=True, we do not modify logging handlers in task_run command
    # If the flag is set to False, we remove all handlers from the root logger
    # and add all handlers from 'airflow.task' logger to the root Logger. This is done
    # to get all the logs from the print & log statements in the DAG files before a task is run
    # The handlers are restored after the task completes execution.
    DONOT_MODIFY_HANDLERS = conf.getboolean('logging', 'donot_modify_handlers', fallback=False)


def configure_orm(disable_connection_pool=False, pool_class=None):
    """Configure ORM using SQLAlchemy"""
    from airflow.utils.log.secrets_masker import mask_secret

    log.debug("Setting up DB connection pool (PID %s)", os.getpid())
    global engine
    global Session
    engine_args = prepare_engine_args(disable_connection_pool, pool_class)

    if conf.has_option('database', 'sql_alchemy_connect_args'):
        connect_args = conf.getimport('database', 'sql_alchemy_connect_args')
    else:
        connect_args = {}

    engine = create_engine(SQL_ALCHEMY_CONN, connect_args=connect_args, **engine_args)

    mask_secret(engine.url.password)

    setup_event_handlers(engine)

    Session = scoped_session(
        sessionmaker(
            autocommit=False,
            autoflush=False,
            bind=engine,
            expire_on_commit=False,
        )
    )
    if engine.dialect.name == 'mssql':
        session = Session()
        try:
            result = session.execute(
                sqlalchemy.text(
                    'SELECT is_read_committed_snapshot_on FROM sys.databases WHERE name=:database_name'
                ),
                params={"database_name": engine.url.database},
            )
            data = result.fetchone()[0]
            if data != 1:
                log.critical("MSSQL database MUST have READ_COMMITTED_SNAPSHOT enabled.")
                log.critical("The database %s has it disabled.", engine.url.database)
                log.critical("This will cause random deadlocks, Refusing to start.")
                log.critical(
                    "See https://airflow.apache.org/docs/apache-airflow/stable/howto/"
                    "set-up-database.html#setting-up-a-mssql-database"
                )
                raise Exception("MSSQL database MUST have READ_COMMITTED_SNAPSHOT enabled.")
        finally:
            session.close()


DEFAULT_ENGINE_ARGS = {
    'postgresql': {
        'executemany_mode': 'values',
        'executemany_values_page_size': 10000,
        'executemany_batch_page_size': 2000,
    },
}


def prepare_engine_args(disable_connection_pool=False, pool_class=None):
    """Prepare SQLAlchemy engine args"""
    default_args = {}
    for dialect, default in DEFAULT_ENGINE_ARGS.items():
        if SQL_ALCHEMY_CONN.startswith(dialect):
            default_args = default.copy()
            break

    engine_args: dict = conf.getjson(
        'database', 'sql_alchemy_engine_args', fallback=default_args
    )  # type: ignore

    if pool_class:
        # Don't use separate settings for size etc, only those from sql_alchemy_engine_args
        engine_args['poolclass'] = pool_class
    elif disable_connection_pool or not conf.getboolean('database', 'SQL_ALCHEMY_POOL_ENABLED'):
        engine_args['poolclass'] = NullPool
        log.debug("settings.prepare_engine_args(): Using NullPool")
    elif not SQL_ALCHEMY_CONN.startswith('sqlite'):
        # Pool size engine args not supported by sqlite.
        # If no config value is defined for the pool size, select a reasonable value.
        # 0 means no limit, which could lead to exceeding the Database connection limit.
        pool_size = conf.getint('database', 'SQL_ALCHEMY_POOL_SIZE', fallback=5)

        # The maximum overflow size of the pool.
        # When the number of checked-out connections reaches the size set in pool_size,
        # additional connections will be returned up to this limit.
        # When those additional connections are returned to the pool, they are disconnected and discarded.
        # It follows then that the total number of simultaneous connections
        # the pool will allow is pool_size + max_overflow,
        # and the total number of "sleeping" connections the pool will allow is pool_size.
        # max_overflow can be set to -1 to indicate no overflow limit;
        # no limit will be placed on the total number
        # of concurrent connections. Defaults to 10.
        max_overflow = conf.getint('database', 'SQL_ALCHEMY_MAX_OVERFLOW', fallback=10)

        # The DB server already has a value for wait_timeout (number of seconds after
        # which an idle sleeping connection should be killed). Since other DBs may
        # co-exist on the same server, SQLAlchemy should set its
        # pool_recycle to an equal or smaller value.
        pool_recycle = conf.getint('database', 'SQL_ALCHEMY_POOL_RECYCLE', fallback=1800)

        # Check connection at the start of each connection pool checkout.
        # Typically, this is a simple statement like "SELECT 1", but may also make use
        # of some DBAPI-specific method to test the connection for liveness.
        # More information here:
        # https://docs.sqlalchemy.org/en/13/core/pooling.html#disconnect-handling-pessimistic
        pool_pre_ping = conf.getboolean('database', 'SQL_ALCHEMY_POOL_PRE_PING', fallback=True)

        log.debug(
            "settings.prepare_engine_args(): Using pool settings. pool_size=%d, max_overflow=%d, "
            "pool_recycle=%d, pid=%d",
            pool_size,
            max_overflow,
            pool_recycle,
            os.getpid(),
        )
        engine_args['pool_size'] = pool_size
        engine_args['pool_recycle'] = pool_recycle
        engine_args['pool_pre_ping'] = pool_pre_ping
        engine_args['max_overflow'] = max_overflow

    # The default isolation level for MySQL (REPEATABLE READ) can introduce inconsistencies when
    # running multiple schedulers, as repeated queries on the same session may read from stale snapshots.
    # 'READ COMMITTED' is the default value for PostgreSQL.
    # More information here:
    # https://dev.mysql.com/doc/refman/8.0/en/innodb-transaction-isolation-levels.html"

    # Similarly MSSQL default isolation level should be set to READ COMMITTED.
    # We also make sure that READ_COMMITTED_SNAPSHOT option is on, in order to avoid deadlocks when
    # Select queries are running. This is by default enforced during init/upgrade. More information:
    # https://docs.microsoft.com/en-us/sql/t-sql/statements/set-transaction-isolation-level-transact-sql

    if SQL_ALCHEMY_CONN.startswith(('mysql', 'mssql')):
        engine_args['isolation_level'] = 'READ COMMITTED'

    # Allow the user to specify an encoding for their DB otherwise default
    # to utf-8 so jobs & users with non-latin1 characters can still use us.
    engine_args['encoding'] = conf.get('database', 'SQL_ENGINE_ENCODING', fallback='utf-8')

    return engine_args


def dispose_orm():
    """Properly close pooled database connections"""
    log.debug("Disposing DB connection pool (PID %s)", os.getpid())
    global engine
    global Session

    if Session:
        Session.remove()
        Session = None
    if engine:
        engine.dispose()
        engine = None


def reconfigure_orm(disable_connection_pool=False, pool_class=None):
    """Properly close database connections and re-configure ORM"""
    dispose_orm()
    configure_orm(disable_connection_pool=disable_connection_pool, pool_class=pool_class)


def configure_adapters():
    """Register Adapters and DB Converters"""
    from pendulum import DateTime as Pendulum

    if SQL_ALCHEMY_CONN.startswith('sqlite'):
        from sqlite3 import register_adapter

        register_adapter(Pendulum, lambda val: val.isoformat(' '))

    if SQL_ALCHEMY_CONN.startswith('mysql'):
        try:
            import MySQLdb.converters

            MySQLdb.converters.conversions[Pendulum] = MySQLdb.converters.DateTime2literal
        except ImportError:
            pass
        try:
            import pymysql.converters

            pymysql.converters.conversions[Pendulum] = pymysql.converters.escape_datetime
        except ImportError:
            pass


def validate_session():
    """Validate ORM Session"""
    global engine

    worker_precheck = conf.getboolean('celery', 'worker_precheck', fallback=False)
    if not worker_precheck:
        return True
    else:
        check_session = sessionmaker(bind=engine)
        session = check_session()
        try:
            session.execute("select 1")
            conn_status = True
        except exc.DBAPIError as err:
            log.error(err)
            conn_status = False
        session.close()
        return conn_status


def configure_action_logging():
    """
    Any additional configuration (register callback) for airflow.utils.action_loggers
    module
    :rtype: None
    """


def prepare_syspath():
    """Ensures that certain subfolders of AIRFLOW_HOME are on the classpath"""
    if DAGS_FOLDER not in sys.path:
        sys.path.append(DAGS_FOLDER)

    # Add ./config/ for loading custom log parsers etc, or
    # airflow_local_settings etc.
    config_path = os.path.join(AIRFLOW_HOME, 'config')
    if config_path not in sys.path:
        sys.path.append(config_path)

    if PLUGINS_FOLDER not in sys.path:
        sys.path.append(PLUGINS_FOLDER)


def get_session_lifetime_config():
    """Gets session timeout configs and handles outdated configs gracefully."""
    session_lifetime_minutes = conf.get('webserver', 'session_lifetime_minutes', fallback=None)
    session_lifetime_days = conf.get('webserver', 'session_lifetime_days', fallback=None)
    uses_deprecated_lifetime_configs = session_lifetime_days or conf.get(
        'webserver', 'force_log_out_after', fallback=None
    )

    minutes_per_day = 24 * 60
    default_lifetime_minutes = '43200'
    if uses_deprecated_lifetime_configs and session_lifetime_minutes == default_lifetime_minutes:
        warnings.warn(
            '`session_lifetime_days` option from `[webserver]` section has been '
            'renamed to `session_lifetime_minutes`. The new option allows to configure '
            'session lifetime in minutes. The `force_log_out_after` option has been removed '
            'from `[webserver]` section. Please update your configuration.',
            category=RemovedInAirflow3Warning,
        )
        if session_lifetime_days:
            session_lifetime_minutes = minutes_per_day * int(session_lifetime_days)

    if not session_lifetime_minutes:
        session_lifetime_days = 30
        session_lifetime_minutes = minutes_per_day * session_lifetime_days

    logging.debug('User session lifetime is set to %s minutes.', session_lifetime_minutes)

    return int(session_lifetime_minutes)


def import_local_settings():
    """Import airflow_local_settings.py files to allow overriding any configs in settings.py file"""
    try:
        import airflow_local_settings

        if hasattr(airflow_local_settings, "__all__"):
            for i in airflow_local_settings.__all__:
                globals()[i] = getattr(airflow_local_settings, i)
        else:
            for k, v in airflow_local_settings.__dict__.items():
                if not k.startswith("__"):
                    globals()[k] = v

        # TODO: Remove once deprecated
        if "policy" in globals() and "task_policy" not in globals():
            warnings.warn(
                "Using `policy` in airflow_local_settings.py is deprecated. "
                "Please rename your `policy` to `task_policy`.",
                DeprecationWarning,
                stacklevel=2,
            )
            globals()["task_policy"] = globals()["policy"]
            del globals()["policy"]

        if not hasattr(task_instance_mutation_hook, 'is_noop'):
            task_instance_mutation_hook.is_noop = False

        log.info("Loaded airflow_local_settings from %s .", airflow_local_settings.__file__)
    except ModuleNotFoundError as e:
        if e.name == "airflow_local_settings":
            log.debug("No airflow_local_settings to import.", exc_info=True)
        else:
            log.critical(
                "Failed to import airflow_local_settings due to a transitive module not found error.",
                exc_info=True,
            )
            raise
    except ImportError:
        log.critical("Failed to import airflow_local_settings.", exc_info=True)
        raise


def initialize():
    """Initialize Airflow with all the settings from this file"""
    configure_vars()
    prepare_syspath()
    import_local_settings()
    global LOGGING_CLASS_PATH
    LOGGING_CLASS_PATH = configure_logging()
    configure_adapters()
    # The webservers import this file from models.py with the default settings.
    configure_orm()
    configure_action_logging()

    # Ensure we close DB connections at scheduler and gunicorn worker terminations
    atexit.register(dispose_orm)


# Const stuff

KILOBYTE = 1024
MEGABYTE = KILOBYTE * KILOBYTE
WEB_COLORS = {'LIGHTBLUE': '#4d9de0', 'LIGHTORANGE': '#FF9933'}


# Updating serialized DAG can not be faster than a minimum interval to reduce database
# write rate.
MIN_SERIALIZED_DAG_UPDATE_INTERVAL = conf.getint('core', 'min_serialized_dag_update_interval', fallback=30)

# If set to True, serialized DAGs is compressed before writing to DB,
COMPRESS_SERIALIZED_DAGS = conf.getboolean('core', 'compress_serialized_dags', fallback=False)

# Fetching serialized DAG can not be faster than a minimum interval to reduce database
# read rate. This config controls when your DAGs are updated in the Webserver
MIN_SERIALIZED_DAG_FETCH_INTERVAL = conf.getint('core', 'min_serialized_dag_fetch_interval', fallback=10)

CAN_FORK = hasattr(os, "fork")

EXECUTE_TASKS_NEW_PYTHON_INTERPRETER = not CAN_FORK or conf.getboolean(
    'core',
    'execute_tasks_new_python_interpreter',
    fallback=False,
)

ALLOW_FUTURE_EXEC_DATES = conf.getboolean('scheduler', 'allow_trigger_in_future', fallback=False)

# Whether or not to check each dagrun against defined SLAs
CHECK_SLAS = conf.getboolean('core', 'check_slas', fallback=True)

USE_JOB_SCHEDULE = conf.getboolean('scheduler', 'use_job_schedule', fallback=True)

# By default Airflow plugins are lazily-loaded (only loaded when required). Set it to False,
# if you want to load plugins whenever 'airflow' is invoked via cli or loaded from module.
LAZY_LOAD_PLUGINS = conf.getboolean('core', 'lazy_load_plugins', fallback=True)

# By default Airflow providers are lazily-discovered (discovery and imports happen only when required).
# Set it to False, if you want to discover providers whenever 'airflow' is invoked via cli or
# loaded from module.
LAZY_LOAD_PROVIDERS = conf.getboolean('core', 'lazy_discover_providers', fallback=True)

# Determines if the executor utilizes Kubernetes
IS_K8S_OR_K8SCELERY_EXECUTOR = conf.get('core', 'EXECUTOR') in {
    executor_constants.KUBERNETES_EXECUTOR,
    executor_constants.CELERY_KUBERNETES_EXECUTOR,
    executor_constants.LOCAL_KUBERNETES_EXECUTOR,
}

HIDE_SENSITIVE_VAR_CONN_FIELDS = conf.getboolean('core', 'hide_sensitive_var_conn_fields')

# By default this is off, but is automatically configured on when running task
# instances
MASK_SECRETS_IN_LOGS = False

# Display alerts on the dashboard
# Useful for warning about setup issues or announcing changes to end users
# List of UIAlerts, which allows for specifying the message, category, and roles the
# message should be shown to. For example:
#   from airflow.www.utils import UIAlert
#
#   DASHBOARD_UIALERTS = [
#       UIAlert("Welcome to Airflow"),  # All users
#       UIAlert("Airflow update happening next week", roles=["User"]),  # Only users with the User role
#       # A flash message with html:
#       UIAlert('Visit <a href="http://airflow.apache.org">airflow.apache.org</a>', html=True),
#   ]
#
DASHBOARD_UIALERTS: list[UIAlert] = []

# Prefix used to identify tables holding data moved during migration.
AIRFLOW_MOVED_TABLE_PREFIX = "_airflow_moved"

DAEMON_UMASK: str = conf.get('core', 'daemon_umask', fallback='0o077')

相关信息

airflow 源码目录

相关文章

airflow init 源码

airflow main 源码

airflow configuration 源码

airflow exceptions 源码

airflow logging_config 源码

airflow plugins_manager 源码

airflow providers_manager 源码

airflow sentry 源码

airflow stats 源码

airflow templates 源码

0  赞