# Copyright (C) 2019 Apple Inc. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
# are met:
# 1. Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# 2. Redistributions in binary form must reproduce the above copyright
# notice, this list of conditions and the following disclaimer in the
# documentation and/or other materials provided with the distribution.
#
# THIS SOFTWARE IS PROVIDED BY APPLE INC. AND ITS CONTRIBUTORS "AS IS" AND
# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL APPLE INC. OR ITS CONTRIBUTORS BE LIABLE FOR
# ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
import mock
import re
import time
from cassandra.metadata import ColumnMetadata, KeyspaceMetadata, Metadata, TableMetadataV3
from cassandra.util import OrderedDict
from collections import defaultdict
from resultsdbpy.model.cassandra_context import CassandraContext
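

# These classes are an in-memory stand-in for the pieces of the Cassandra driver and
# cqlengine that resultsdbpy.model.cassandra_context relies on, so CassandraContext can be
# exercised in unit tests without a running Cassandra server. All table data lives in the
# class-level MockCluster.database dictionary, keyed by
# keyspace -> table -> partition-key tuple -> clustering-key tuple.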


class MockCluster(object):
    metadata = Metadata()
    database = {}

    def __init__(self, contact_points=["127.0.0.1"], port=9042, auth_provider=None):
        assert isinstance(contact_points, list)
        self.contact_points = contact_points
        self.port = port
        self.sessions = []

    def connect(self, keyspace=None):
        result = MockSession(self, keyspace)
        self.sessions.append(result)
        return result

    def shutdown(self):
        self.sessions = []


class MockSession(object):
    def __init__(self, cluster, keyspace=None):
        self.cluster = cluster
        self.keyspace = keyspace

    def execute(self, *args):
        pass


class MockCQLEngineContext(object):
    connections = {}

    def __init__(self):
        self._depth = 0
        self.patches = [
            mock.patch('resultsdbpy.model.cassandra_context.Cluster', new=MockCluster),
            mock.patch('resultsdbpy.model.cassandra_context.register_connection', new=self.register_connection),
            mock.patch('resultsdbpy.model.cassandra_context.unregister_connection', new=self.unregister_connection),
            mock.patch('resultsdbpy.model.cassandra_context.get_cluster', new=self.get_cluster),
            mock.patch('resultsdbpy.model.cassandra_context.create_keyspace_simple', new=self.create_keyspace_simple),
            mock.patch('resultsdbpy.model.cassandra_context.create_keyspace_network_topology', new=self.create_keyspace_network_topology),
            mock.patch('resultsdbpy.model.cassandra_context.drop_keyspace', new=self.drop_keyspace),
            mock.patch('resultsdbpy.model.cassandra_context.sync_table', new=self.sync_table),
        ]

    def __enter__(self):
        if self._depth == 0:
            for patch in self.patches:
                patch.__enter__()
        self._depth += 1

    def __exit__(self, *args, **kwargs):
        self._depth -= 1
        if self._depth <= 0:
            for patch in self.patches:
                patch.__exit__(*args, **kwargs)
    @classmethod
    def register_connection(cls, name, session, **kwargs):
        if session is None:
            session = MockCluster().connect()
        cls.connections[name] = (session.cluster, session)

    @classmethod
    def unregister_connection(cls, name):
        if name in cls.connections:
            del cls.connections[name]

    @classmethod
    def get_cluster(cls, connection):
        return cls.connections.get(connection, [None])[0]

    @classmethod
    def create_keyspace_simple(cls, name, replication_factor=1, durable_writes=True, connections=[]):
        for connection in connections:
            assert connection in cls.connections
            if name not in cls.get_cluster(connection).metadata.keyspaces:
                cls.get_cluster(connection).metadata.keyspaces[name] = KeyspaceMetadata(name, durable_writes, None, None)
                MockCluster.database[name] = {}

    @classmethod
    def create_keyspace_network_topology(cls, name, dc_replication_map={}, durable_writes=True, connections=[]):
        return cls.create_keyspace_simple(name, durable_writes=durable_writes, connections=connections)

    @classmethod
    def drop_keyspace(cls, name, connections=[]):
        for connection in connections:
            assert connection in cls.connections
            if name in cls.get_cluster(connection).metadata.keyspaces:
                del cls.get_cluster(connection).metadata.keyspaces[name]
                del MockCluster.database[name]
    @classmethod
    def sync_table(cls, model, keyspaces=[], connections=[]):
        for connection in connections:
            assert connection in cls.connections
            for keyspace in keyspaces:
                keyspace_metadata = MockCQLEngineContext.get_cluster(connection).metadata.keyspaces.get(keyspace)
                assert keyspace_metadata is not None

                # Partition and clustering columns come first, in model order, followed by regular columns.
                primary_keys = []
                column_keys = []
                for key, value in model._columns.items():
                    if value.partition_key or value.primary_key:
                        primary_keys.append(key)
                    else:
                        column_keys.append(key)

                columns = OrderedDict()
                partition_keys = []
                clustering_keys = []
                for key in primary_keys + column_keys:
                    value = model._columns[key]
                    meta_data = ColumnMetadata(None, key, value.db_type, is_reversed=(value.clustering_order == 'DESC'))
                    columns[key] = meta_data
                    if value.partition_key:
                        partition_keys.append(meta_data)
                    if value.primary_key:
                        clustering_keys.append(meta_data)

                # Register empty in-memory storage and table metadata for the model's column family.
                MockCluster.database[keyspace][model._raw_column_family_name()] = defaultdict(lambda: {})
                keyspace_metadata.tables[model._raw_column_family_name()] = TableMetadataV3(
                    keyspace, model._raw_column_family_name(),
                    columns=columns, partition_key=partition_keys, clustering_key=clustering_keys,
                )
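
# A minimal usage sketch (an assumption based only on the patches registered above, not on
# how the rest of resultsdbpy drives this class): entering the context swaps the cqlengine
# management functions imported by resultsdbpy.model.cassandra_context for the in-memory
# versions, and nested with-blocks are reference-counted through _depth, so the patches are
# only torn down when the outermost block exits.
#
#     engine = MockCQLEngineContext()
#     with engine:
#         with engine:
#             pass  # calls to sync_table, register_connection, etc. hit MockCluster.database
#         # patches are still active here
#     # patches have been removed here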


class MockCassandraContext(CassandraContext):
    @classmethod
    def drop_keyspace(cls, *args, **kwargs):
        with MockCQLEngineContext():
            CassandraContext.drop_keyspace(*args, **kwargs)

    def __init__(self, *args, **kwargs):
        self._mock_cql_engine = MockCQLEngineContext()
        with self._mock_cql_engine:
            super(MockCassandraContext, self).__init__(*args, **kwargs)

    def __enter__(self):
        self._mock_cql_engine.__enter__()
        super(MockCassandraContext, self).__enter__()

    def __exit__(self, *args, **kwargs):
        super(MockCassandraContext, self).__exit__(*args, **kwargs)
        self._mock_cql_engine.__exit__(*args, **kwargs)

    @CassandraContext.AssertConnectedDecorator()
    def insert_row(self, table_name, ttl=None, **kwargs):
        if table_name not in self._models:
            raise self.SchemaException(f'{table_name} does not exist in the database')
        item = self._models[table_name](**kwargs)
        schema = self.schema_for_table(table_name)

        partition = []
        for column in schema.partition_key:
            partition.append(getattr(item, column.name))

        cluster = []
        for column in schema.primary_key:
            if getattr(item, column.name) not in partition:
                cluster.append(getattr(item, column.name))

        # Rows are stored as (expiration timestamp or None, model instance); expired rows
        # are pruned lazily by select_from_table.
        self.cluster.database[self.keyspace][table_name][tuple(partition)][tuple(cluster)] = (None if ttl is None else time.time() + ttl, item)
    @CassandraContext.AssertConnectedDecorator()
    def select_from_table(self, table_name, limit=10000, **kwargs):
        result = []
        schema = self.schema_for_table(table_name)
        partition_names = []
        for column in schema.partition_key:
            partition_names.append(column.name)

        # Sort kwargs so they are in the same order as the columns.
        ordered_kwargs = OrderedDict()
        for name in self._models[table_name]._columns.keys():
            for key, value in kwargs.items():
                if key.split('__')[0] == name and value is not None:
                    ordered_kwargs[key] = value

        # Convert arguments to filters. Partition-key arguments also select which partitions
        # to scan; an 'in' argument multiplies the candidate partition list by its values.
        partitions = [[]]
        filters = []
        for key, value in ordered_kwargs.items():
            key_value = key.split('__')[0]
            operator = None if len(key.split('__')) == 1 else key.split('__')[1]
            if key_value in partition_names:
                if operator == 'in':
                    original_partitions = list(partitions)
                    partitions = []
                    for _ in range(len(value)):
                        for element in original_partitions:
                            partitions.append(list(element))
                    for multiplier in range(len(value)):
                        for index in range(len(original_partitions)):
                            partitions[multiplier * len(original_partitions) + index].append(value[multiplier])
                else:
                    for partition in partitions:
                        partition.append(value)
            filters.append(self.filter_for_argument(key, value))

        # No partition-key arguments means every partition is a candidate.
        if len(partitions[0]) == 0:
            partitions = self.cluster.database[self.keyspace][table_name].keys()

        candidate_elements = []
        for partition in partitions:
            for key, row in self.cluster.database[self.keyspace][table_name][tuple(partition)].items():
                does_match = True
                for filter in filters:
                    if not filter(row[1]):
                        does_match = False
                        break
                if not does_match:
                    continue
                candidate_elements.append((partition, key))

        # To reverse, all clustering columns need to agree. This might not work with compound
        # clustering keys, but we really shouldn't use those anyway.
        is_reversed = True
        for column in schema.primary_key:
            if column in schema.partition_key:
                continue
            if not column.is_reversed:
                is_reversed = False
        candidate_elements.sort(reverse=is_reversed)

        for candidate in candidate_elements:
            element = self.cluster.database[self.keyspace][table_name][tuple(candidate[0])][candidate[1]]
            deadline = element[0]
            if deadline is None or deadline > time.time():
                result.append(element[1])
                limit -= 1
                if limit <= 0:
                    return result
            else:
                # Lazily drop rows whose ttl has expired.
                del self.cluster.database[self.keyspace][table_name][tuple(candidate[0])][candidate[1]]
        return result
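
# A rough end-to-end sketch of how the mock is meant to be driven from a test. The model
# class, keyspace name, constructor arguments, and the table-registration step are
# hypothetical here; the real registration API lives in CassandraContext, which is not part
# of this file.
#
#     from cassandra.cqlengine import columns
#     from cassandra.cqlengine.models import Model
#
#     class ExampleModel(Model):          # hypothetical model for illustration only
#         __table_name__ = 'example_table'
#         key = columns.Text(partition_key=True)
#         index = columns.Integer(primary_key=True)
#         payload = columns.Text()
#
#     context = MockCassandraContext(keyspace='mock_keyspace')   # constructor arguments assumed
#     with context:
#         ...  # register ExampleModel with the context via whatever CassandraContext provides
#         context.insert_row('example_table', ttl=60, key='a', index=1, payload='value')
#         rows = context.select_from_table('example_table', key='a', limit=100)
#
# insert_row stores each row as (expiration-or-None, model instance) under its partition and
# clustering keys, and select_from_table drops expired rows lazily while honoring limit.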