File: //proc/self/root/usr/lib/python3/dist-packages/awscli/customizations/history/db.py
# Copyright 2017 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
#     http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
import uuid
import time
import json
import datetime
import threading
import logging
from awscli.compat import collections_abc
from botocore.history import BaseHistoryHandler
from awscli.compat import sqlite3
from awscli.compat import binary_type
# Module-level logger for this module; used to report, at debug level, when
# enabling sqlite WAL mode fails.
LOG = logging.getLogger(__name__)
class DatabaseConnection(object):
    """Wrap a sqlite3 connection that stores CLI history records.

    Opening the wrapper creates the ``records`` table when it does not yet
    exist and then tries to turn on write-ahead logging (WAL).
    """

    _CREATE_TABLE = """
        CREATE TABLE IF NOT EXISTS records (
          id TEXT,
          request_id TEXT,
          source TEXT,
          event_type TEXT,
          timestamp INTEGER,
          payload TEXT
        )"""
    _ENABLE_WAL = 'PRAGMA journal_mode=WAL'

    def __init__(self, db_filename):
        # isolation_level=None puts sqlite3 in autocommit mode;
        # check_same_thread=False allows use from threads other than the
        # one that opened the connection.
        self._connection = sqlite3.connect(
            db_filename, isolation_level=None, check_same_thread=False)
        self._ensure_database_setup()

    def close(self):
        """Close the wrapped sqlite3 connection."""
        self._connection.close()

    def execute(self, query, *parameters):
        """Execute ``query`` (with optional bind parameters); return the cursor."""
        connection = self._connection
        return connection.execute(query, *parameters)

    def _ensure_database_setup(self):
        # The table is created first, then the optional WAL switch is tried.
        for setup_step in (self._create_record_table, self._try_to_enable_wal):
            setup_step()

    def _create_record_table(self):
        self.execute(self._CREATE_TABLE)

    def _try_to_enable_wal(self):
        try:
            self.execute(self._ENABLE_WAL)
        except sqlite3.Error:
            # WAL is only a performance enhancement, and not every sqlite
            # build supports it, so a failure is logged and otherwise ignored.
            LOG.debug('Failed to enable sqlite WAL.')

    @property
    def row_factory(self):
        """Row factory of the wrapped connection (readable and writable)."""
        return self._connection.row_factory

    @row_factory.setter
    def row_factory(self, row_factory):
        self._connection.row_factory = row_factory
class PayloadSerializer(json.JSONEncoder):
    """JSON encoder for history record payloads.

    Beyond the stock ``json`` types it encodes ``datetime`` objects as ISO
    8601 strings and mutable mappings as plain dicts. Byte strings are
    decoded as UTF-8, or replaced with ``'<Byte sequence>'`` when they are
    not valid UTF-8. Any other unsupported object is encoded as its
    ``repr``.
    """

    def _encode_mutable_mapping(self, obj):
        return dict(obj)

    def _encode_datetime(self, obj):
        return obj.isoformat()

    def _try_decode_bytes(self, obj):
        """Decode ``obj`` as UTF-8, or return ``'<Byte sequence>'`` if invalid."""
        try:
            obj = obj.decode('utf-8')
        except UnicodeDecodeError:
            obj = '<Byte sequence>'
        return obj

    def _remove_non_unicode_stings(self, obj):
        """Return ``obj`` with every byte string made UTF-8 safe.

        Recurses into mutable mappings (returned as dicts), lists and tuples
        (both returned as lists). Byte strings go through
        ``_try_decode_bytes``. Any other value is returned unchanged.
        """
        # Test for binary_type rather than str. On PY2 the two are the same
        # type. On PY3, str has no ``decode`` method, so checking str made
        # every text value in the payload raise AttributeError here.
        if isinstance(obj, binary_type):
            obj = self._try_decode_bytes(obj)
        elif isinstance(obj, collections_abc.MutableMapping):
            # Same mapping -> dict conversion default() performs.
            obj = dict((k, self._remove_non_unicode_stings(v))
                       for k, v in obj.items())
        elif isinstance(obj, (list, tuple)):
            obj = [self._remove_non_unicode_stings(o) for o in obj]
        return obj

    def encode(self, obj):
        """Encode ``obj``; on a UnicodeDecodeError, scrub it and retry once."""
        try:
            return super(PayloadSerializer, self).encode(obj)
        except UnicodeDecodeError:
            # This happens in PY2 when a record payload contains binary data
            # that is not valid utf-8. PY2 does not call ``default`` for such
            # a field because it believes its normal str serialization can
            # handle it. It cannot tell a utf-8 str from a str holding raw
            # bytes, so the UnicodeDecodeError surfaces here at the top level.
            # PY2 offers no hook into the serialization process to fix this.
            # Instead, walk the structure ourselves, replace every string
            # that is not utf-8 decodable, and encode again.
            scrubbed_obj = self._remove_non_unicode_stings(obj)
            return super(PayloadSerializer, self).encode(scrubbed_obj)

    def default(self, obj):
        """Return a JSON-serializable stand-in for an unsupported ``obj``."""
        if isinstance(obj, datetime.datetime):
            return self._encode_datetime(obj)
        elif isinstance(obj, collections_abc.MutableMapping):
            return self._encode_mutable_mapping(obj)
        elif isinstance(obj, binary_type):
            # In PY3, bytes is a separate type from str, so ``default`` is
            # called when a bytes object is encountered. _try_decode_bytes
            # either decodes it to a utf-8 string, or replaces it with a
            # placeholder if it is not valid utf-8.
            return self._try_decode_bytes(obj)
        else:
            return repr(obj)
class DatabaseRecordWriter(object):
    """Serialize history records and insert them into the ``records`` table."""

    _WRITE_RECORD = """
        INSERT INTO records(
            id, request_id, source, event_type, timestamp, payload)
        VALUES (?,?,?,?,?,?) """

    def __init__(self, connection):
        self._connection = connection
        # Held around each insert so concurrent write_record calls do not
        # use the connection at the same time.
        self._lock = threading.Lock()

    def close(self):
        """Close the underlying connection."""
        self._connection.close()

    def write_record(self, record):
        """Serialize ``record`` (outside the lock) and insert it as one row."""
        row = self._create_db_record(record)
        with self._lock:
            self._connection.execute(self._WRITE_RECORD, row)

    def _create_db_record(self, record):
        # Build the row tuple in the column order used by _WRITE_RECORD.
        # request_id is optional; every other key must be present.
        event_type = record['event_type']
        payload_json = json.dumps(record['payload'], cls=PayloadSerializer)
        return (
            record['command_id'],
            record.get('request_id'),
            record['source'],
            event_type,
            record['timestamp'],
            payload_json,
        )
class DatabaseRecordReader(object):
    """Read history records from the ``records`` table.

    Installs a row factory on the connection that returns each row as a dict
    keyed by column name and JSON-decodes the ``payload`` column.
    """

    _ORDERING = 'ORDER BY timestamp'
    # Every record that shares its id with the most recently written record.
    _GET_LAST_ID_RECORDS = """
        SELECT * FROM records
        WHERE id =
        (SELECT id FROM records WHERE timestamp =
        (SELECT max(timestamp) FROM records)) %s;""" % _ORDERING
    _GET_RECORDS_BY_ID = 'SELECT * from records where id = ? %s' % _ORDERING
    # One row per id, pairing its CLI_ARGUMENTS payload (args) with its
    # CLI_RC payload (rc), newest first. String literals are single-quoted.
    # SQLite treats double-quoted tokens as identifiers and only falls back
    # to string literals as a legacy quirk, which builds compiled with
    # SQLITE_DQS=0 reject.
    _GET_ALL_RECORDS = (
        'SELECT a.id AS id_a, '
        '    b.id AS id_b, '
        '    a.timestamp as timestamp, '
        '    a.payload AS args, '
        '    b.payload AS rc '
        'FROM records a, records b '
        "where a.event_type = 'CLI_ARGUMENTS' AND "
        "    b.event_type = 'CLI_RC' AND "
        '    a.id = b.id '
        '%s DESC' % _ORDERING
    )

    def __init__(self, connection):
        self._connection = connection
        self._connection.row_factory = self._row_factory

    def close(self):
        """Close the underlying connection."""
        self._connection.close()

    def _row_factory(self, cursor, row):
        """Map a result row to a ``{column_name: value}`` dict.

        A column named exactly ``payload`` is JSON-decoded. A NULL payload
        is returned as None instead of being passed to ``json.loads``.
        """
        decoded = {}
        for column, value in zip(cursor.description, row):
            name = column[0]
            if name == 'payload' and value is not None:
                value = json.loads(value)
            decoded[name] = value
        return decoded

    def iter_latest_records(self):
        """Yield all records for the id of the newest record, oldest first."""
        cursor = self._connection.execute(self._GET_LAST_ID_RECORDS)
        for row in cursor:
            yield row

    def iter_records(self, record_id):
        """Yield the records whose id equals ``record_id``, oldest first."""
        cursor = self._connection.execute(self._GET_RECORDS_BY_ID, [record_id])
        for row in cursor:
            yield row

    def iter_all_records(self):
        """Yield one args/rc row per id, newest first (payloads left raw)."""
        cursor = self._connection.execute(self._GET_ALL_RECORDS)
        for row in cursor:
            yield row
class RecordBuilder(object):
    """Build history record dicts tagged with command and request ids.

    Every record from one builder carries the same ``command_id``, generated
    lazily on first use. Records for request-lifecycle events also carry a
    ``request_id``. A new, thread-local request id is generated each time an
    ``API_CALL`` event is built.
    """

    _REQUEST_LIFECYCLE_EVENTS = {
        'API_CALL', 'HTTP_REQUEST', 'HTTP_RESPONSE', 'PARSED_RESPONSE'}
    _START_OF_REQUEST_LIFECYCLE_EVENT = 'API_CALL'

    def __init__(self):
        self._identifier = None
        # Per-thread storage for the active request id.
        self._locals = threading.local()

    def _get_current_thread_request_id(self):
        return getattr(self._locals, 'request_id', None)

    def _start_http_lifecycle(self):
        self._locals.request_id = str(uuid.uuid4())

    def _get_request_id(self, event_type):
        # An API_CALL opens a new lifecycle before its own id is looked up.
        if event_type == self._START_OF_REQUEST_LIFECYCLE_EVENT:
            self._start_http_lifecycle()
        if event_type not in self._REQUEST_LIFECYCLE_EVENTS:
            return None
        return self._get_current_thread_request_id()

    def _get_identifier(self):
        if self._identifier is None:
            self._identifier = str(uuid.uuid4())
        return self._identifier

    def build_record(self, event_type, payload, source):
        """Return a record dict for one event.

        Keys: command_id, event_type, payload, source, timestamp (integer
        milliseconds since the epoch), plus request_id when one applies.
        """
        record = dict(
            command_id=self._get_identifier(),
            event_type=event_type,
            payload=payload,
            source=source,
            timestamp=int(time.time() * 1000),
        )
        request_id = self._get_request_id(event_type)
        if request_id:
            record['request_id'] = request_id
        return record
class DatabaseHistoryHandler(BaseHistoryHandler):
    """History handler that turns each emitted event into a stored record.

    ``record_builder`` builds the record; ``writer`` persists it.
    """

    def __init__(self, writer, record_builder):
        self._writer = writer
        self._record_builder = record_builder

    def emit(self, event_type, payload, source):
        """Build a record for the event and pass it to the writer."""
        build = self._record_builder.build_record
        self._writer.write_record(build(event_type, payload, source))