|  | 
| 27 | 27 | import json | 
| 28 | 28 | import logging | 
| 29 | 29 | import os | 
|  | 30 | +import signal | 
|  | 31 | +import socket | 
| 30 | 32 | import sys | 
| 31 | 33 | import time | 
| 32 |  | -import signal | 
| 33 |  | -import yaml | 
| 34 | 34 | import uuid | 
|  | 35 | +import yaml | 
| 35 | 36 | 
 | 
| 36 | 37 | import boto3 | 
| 37 | 38 | from botocore.handlers import validate_bucket_name | 
| 38 | 39 | import cloudevents.http | 
| 39 | 40 | import confluent_kafka as kafka | 
|  | 41 | +from confluent_kafka.serialization import SerializationContext, MessageField | 
|  | 42 | +from confluent_kafka.schema_registry import SchemaRegistryClient | 
|  | 43 | +from confluent_kafka.schema_registry.avro import AvroDeserializer | 
| 40 | 44 | import flask | 
| 41 | 45 | 
 | 
| 42 | 46 | from .config import PipelinesConfig | 
|  | 
| 53 | 57 | from .repo_tracker import LocalRepoTracker | 
| 54 | 58 | from .visit import FannedOutVisit | 
| 55 | 59 | 
 | 
| 56 |  | - | 
|  | 60 | +# Platform that prompt processing will run on | 
|  | 61 | +platform = os.environ["PLATFORM"].lower() | 
| 57 | 62 | # The short name for the instrument. | 
| 58 | 63 | instrument_name = os.environ["RUBIN_INSTRUMENT"] | 
| 59 | 64 | # The skymap to use in the central repo | 
|  | 
| 76 | 81 | kafka_group_id = str(uuid.uuid4()) | 
| 77 | 82 | # The topic on which to listen to updates to image_bucket | 
| 78 | 83 | bucket_topic = os.environ.get("BUCKET_TOPIC", "rubin-prompt-processing") | 
|  | 84 | +# Offset reset policy for the Kafka bucket-notification consumer. | 
|  | 85 | +bucket_notification_kafka_offset_reset = os.environ.get("BUCKET_NOTIFICATION_KAFKA_OFFSET_RESET", "latest") | 
|  | 86 | + | 
|  | 87 | +# Conditionally load keda environment variables | 
|  | 88 | +if platform == "keda": | 
|  | 89 | +    # Kafka Schema Registry URL for next visit fan out messages | 
|  | 90 | +    fan_out_schema_registry_url = os.environ["FAN_OUT_SCHEMA_REGISTRY_URL"] | 
|  | 91 | +    # Kafka cluster with next visit fanned out messages. | 
|  | 92 | +    fan_out_kafka_cluster = os.environ["FAN_OUT_KAFKA_CLUSTER"] | 
|  | 93 | +    # Kafka group for next visit fan out messages. | 
|  | 94 | +    fan_out_kafka_group_id = os.environ["FAN_OUT_KAFKA_GROUP_ID"] | 
|  | 95 | +    # Kafka topic for next visit fan out messages. | 
|  | 96 | +    fan_out_kafka_topic = os.environ["FAN_OUT_KAFKA_TOPIC"] | 
|  | 97 | +    # Kafka topic offset for next visit fan out messages. | 
|  | 98 | +    fan_out_kafka_topic_offset = os.environ["FAN_OUT_KAFKA_TOPIC_OFFSET"] | 
|  | 99 | +    # Kafka Fan Out SASL Mechanism. | 
|  | 100 | +    fan_out_kafka_sasl_mechanism = os.environ["FAN_OUT_KAFKA_SASL_MECHANISM"] | 
|  | 101 | +    # Kafka Fan Out Security Protocol. | 
|  | 102 | +    fan_out_kafka_security_protocol = os.environ["FAN_OUT_KAFKA_SECURITY_PROTOCOL"] | 
|  | 103 | +    # Kafka Fan Out Consumer Username. | 
|  | 104 | +    fan_out_kafka_sasl_username = os.environ["FAN_OUT_KAFKA_SASL_USERNAME"] | 
|  | 105 | +    # Kafka Fan Out Consumer Password. | 
|  | 106 | +    fan_out_kafka_sasl_password = os.environ["FAN_OUT_KAFKA_SASL_PASSWORD"] | 
|  | 107 | +    # Time to wait for fanned out messages before spawning new pod. | 
|  | 108 | +    fanned_out_msg_listen_timeout = int(os.environ.get("FANNED_OUT_MSG_LISTEN_TIMEOUT", 300)) | 
| 79 | 109 | 
 | 
| 80 | 110 | _log = logging.getLogger("lsst." + __name__) | 
| 81 | 111 | _log.setLevel(logging.DEBUG) | 
| @@ -127,7 +157,7 @@ def _get_consumer(): | 
| 127 | 157 |     return kafka.Consumer({ | 
| 128 | 158 |         "bootstrap.servers": kafka_cluster, | 
| 129 | 159 |         "group.id": kafka_group_id, | 
| 130 |  | -        "auto.offset.reset": "latest",  # default, but make explicit | 
|  | 160 | +        "auto.offset.reset": bucket_notification_kafka_offset_reset, | 
| 131 | 161 |     }) | 
| 132 | 162 | 
 | 
| 133 | 163 | 
 | 
| @@ -195,6 +225,107 @@ def create_app(): | 
| 195 | 225 |         sys.exit(3) | 
| 196 | 226 | 
 | 
| 197 | 227 | 
 | 
def dict_to_fanned_out_visit(obj, ctx):
    """Deserialize an object literal into a `FannedOutVisit`.

    Intended for use as the ``from_dict`` callback of an
    ``AvroDeserializer``.

    Args:
        obj (dict): Object literal (dict) produced by the Avro decoder,
            or `None` if there was no payload.
        ctx (SerializationContext): Metadata pertaining to the
            serialization operation (unused).

    Returns:
        `FannedOutVisit` built from ``obj``, or `None` if ``obj`` is
        `None`.
    """
    return None if obj is None else FannedOutVisit(**obj)
|  | 241 | + | 
|  | 242 | + | 
def keda_start():
    """Entry point for a KEDA-scaled worker pod.

    Sets up logging and the local repo tracker, fail-fast-checks the
    service clients, then listens on the fan-out Kafka topic for a single
    next-visit message.  The first successfully deserialized message is
    committed and processed, after which the function returns so the pod
    can exit.  Listening stops after ``fanned_out_msg_listen_timeout``
    seconds without a usable message.
    """
    try:
        setup_usdf_logger(
            labels={"instrument": instrument_name},
        )

        # Initialize local registry
        registry = LocalRepoTracker.get()
        registry.init_tracker()

        # Check initialization and abort early
        _get_consumer()
        _get_storage_client()
        _get_central_butler()
        _get_local_repo()

        _log.info("Worker ready to handle requests.")

    except Exception as e:
        _log.critical("Failed to start worker; aborting.")
        _log.exception(e)
        sys.exit(1)

    # Initialize schema registry for fan out
    fan_out_schema_registry_conf = {'url': fan_out_schema_registry_url}
    fan_out_schema_registry_client = SchemaRegistryClient(fan_out_schema_registry_conf)

    fan_out_avro_deserializer = AvroDeserializer(schema_registry_client=fan_out_schema_registry_client,
                                                 from_dict=dict_to_fanned_out_visit)
    fan_out_consumer_conf = {
        "bootstrap.servers": fan_out_kafka_cluster,
        "group.id": fan_out_kafka_group_id,
        "auto.offset.reset": fan_out_kafka_topic_offset,
        "sasl.mechanism": fan_out_kafka_sasl_mechanism,
        "security.protocol": fan_out_kafka_security_protocol,
        "sasl.username": fan_out_kafka_sasl_username,
        "sasl.password": fan_out_kafka_sasl_password,
        # Manual commit: the offset is committed explicitly below, just
        # before processing the visit.
        'enable.auto.commit': False
    }

    _log.info("starting fan out consumer")
    fan_out_consumer = kafka.Consumer(fan_out_consumer_conf, logger=_log)
    fan_out_consumer.subscribe([fan_out_kafka_topic])
    fan_out_listen_start_time = time.time()

    try:
        while time.time() - fan_out_listen_start_time < fanned_out_msg_listen_timeout:

            fan_out_message = fan_out_consumer.poll(timeout=5)
            if fan_out_message is None:
                continue
            if fan_out_message.error():
                _log.warning("Fanned out consumer error: %s", fan_out_message.error())
                continue

            deserialized_fan_out_visit = fan_out_avro_deserializer(
                fan_out_message.value(),
                SerializationContext(fan_out_message.topic(), MessageField.VALUE))
            _log.info("Unpacked message as %r.", deserialized_fan_out_visit)

            # Calculate time to load knative and receive message based on
            # timestamp in Kafka message.
            # NOTE(review): Message.timestamp() returns a (type, value)
            # tuple; value may be -1 when the broker supplies no
            # timestamp, making this latency figure meaningless -- confirm
            # the cluster sets message timestamps.
            _log.debug("Message timestamp %r", fan_out_message.timestamp())
            fan_out_kafka_msg_timestamp = fan_out_message.timestamp()
            fan_out_to_prompt_time = int(time.time() * 1000) - fan_out_kafka_msg_timestamp[1]
            _log.debug("Seconds since fan out message delivered %r", fan_out_to_prompt_time/1000)

            # Commit before processing: at-most-once delivery, so a failed
            # visit is not redelivered to another pod.
            fan_out_consumer.commit(message=fan_out_message, asynchronous=False)

            try:
                # Process fan out visit
                process_visit(deserialized_fan_out_visit)
            except Exception as e:
                _log.critical("Process visit failed; aborting.")
                _log.exception(e)
            finally:
                _log.info("Processing completed for %s", socket.gethostname())
            break

    finally:
        # Close the consumer on every exit path (message handled, listen
        # timeout, or unexpected error).  The previous version closed it
        # inline only after a successful message, leaking the consumer --
        # and its group membership -- when the listen window timed out.
        fan_out_consumer.close()
        # TODO Handle local registry unregistration on DM-47975
        _log.info("Finished listening for fanned out messages")
|  | 328 | + | 
| 198 | 329 | def _graceful_shutdown(signum: int, stack_frame): | 
| 199 | 330 |     """Signal handler for cases where the service should gracefully shut down. | 
| 200 | 331 | 
 | 
| @@ -586,10 +717,18 @@ def server_error(e: Exception) -> tuple[str, int]: | 
| 586 | 717 | 
 | 
| 587 | 718 | 
 | 
def main():
    """Start the service in the mode selected by the PLATFORM env var.

    Knative deployments call `create_app()()` through Gunicorn; running
    this function directly starts a standalone Flask app instead.  Keda
    deployments invoke main to run the one-shot fanned-out consumer.

    Exits with status 1 if PLATFORM has an unsupported value.
    """
    if platform == "knative":
        _log.info("starting standalone Flask app")
        app = create_app()
        app.run(host="127.0.0.1", port=8080, debug=True)
    # starts keda instance of the application
    elif platform == "keda":
        _log.info("starting keda instance")
        keda_start()
    else:
        # PLATFORM is always set (it is read unconditionally at import
        # time), so reaching here means an unsupported value, not a
        # missing one.  Fail loudly instead of logging at INFO and
        # exiting 0, which hid the misconfiguration.
        _log.error("Unsupported PLATFORM %r; expected 'knative' or 'keda'.", platform)
        sys.exit(1)
| 593 | 732 | 
 | 
| 594 | 733 | 
 | 
| 595 | 734 | if __name__ == "__main__": | 
|  | 
0 commit comments