blob: 834254ca9dc3a1941115afb6666bbac1407bc07a [file] [log] [blame]
# Copyright (C) 2019 Apple Inc. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
# are met:
# 1. Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# 2. Redistributions in binary form must reproduce the above copyright
# notice, this list of conditions and the following disclaimer in the
# documentation and/or other materials provided with the distribution.
#
# THIS SOFTWARE IS PROVIDED BY APPLE INC. AND ITS CONTRIBUTORS "AS IS" AND
# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL APPLE INC. OR ITS CONTRIBUTORS BE LIABLE FOR
# ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
import contextlib
import os
import re
import uuid
from cassandra.cluster import Cluster
from cassandra.cqlengine.columns import Text
from cassandra.cqlengine.connection import register_connection, unregister_connection
from cassandra.cqlengine.management import CQLENG_ALLOW_SCHEMA_MANAGEMENT, get_cluster, create_keyspace_network_topology, create_keyspace_simple, drop_keyspace, sync_table
from cassandra.cqlengine.models import Model
from cassandra.cqlengine.query import BatchQuery
class RegexCluster(Text):
    """A Text column which must be a clustering (non-partition) primary key.

    Columns of this type are given a SASI index when their table is created
    (see CassandraContext.create_table), which allows LIKE operations to be
    performed against them efficiently.
    """

    def __init__(self, min_length=1, max_length=None, primary_key=True, partition_key=False, **kwargs):
        # The keyword arguments exist for signature compatibility with Text,
        # but this column only makes sense as a primary key that is not part
        # of the partition key.
        assert primary_key
        assert not partition_key
        kwargs.update(
            min_length=min_length,
            max_length=max_length,
            primary_key=primary_key,
            partition_key=partition_key,
        )
        super().__init__(**kwargs)
class CountedBatchQuery(BatchQuery):
    """A BatchQuery which automatically flushes itself once it buffers a
    fixed number of queries, keeping individual batches bounded in size.
    """

    DEFAULT_LIMIT = 100

    def __init__(self, limit=DEFAULT_LIMIT, **kwargs):
        super().__init__(**kwargs)
        self.limit = limit  # Maximum number of queries buffered before an automatic flush.

    def add_query(self, query):
        # Execute the pending batch before it grows beyond self.limit, then
        # clear the executed flag so this batch object keeps accepting queries.
        if len(self.queries) >= self.limit:
            self.execute()
            self._executed = False
        return super().add_query(query)
class CassandraContext(object):
    """Manages a cqlengine connection to a single Cassandra keyspace.

    Instances are re-entrant context managers: the session is registered on
    the first __enter__ and unregistered by the matching outermost __exit__.
    Methods decorated with AssertConnectedDecorator may only be called while
    inside the context.
    """

    @classmethod
    def can_modify_schema(cls):
        # cqlengine requires this environment variable to be set before it
        # will modify schema. Note os.getenv returns the raw string, so any
        # non-empty value (even '0') is treated as enabled.
        return os.getenv(CQLENG_ALLOW_SCHEMA_MANAGEMENT, False)

    @classmethod
    def drop_keyspace(cls, nodes=None, keyspace='results_database', auth_provider=None):
        """Drop the provided keyspace, if it exists, over a temporary connection.

        Arguments:
            nodes -- list of Cassandra hosts to connect to (default: localhost).
            keyspace -- name of the keyspace to be dropped.
            auth_provider -- optional authentication provider for the Cluster.
        """
        nodes = nodes if nodes else ['localhost']
        connection_id = uuid.uuid4()
        try:
            register_connection(name=str(connection_id), session=Cluster(nodes, auth_provider=auth_provider).connect())
            does_keyspace_exist = keyspace in get_cluster(str(connection_id)).metadata.keyspaces
            if does_keyspace_exist:
                # This resolves to the module-level cqlengine drop_keyspace,
                # not this classmethod: class attributes are not in scope
                # inside method bodies.
                drop_keyspace(keyspace, connections=[str(connection_id)])
        finally:
            unregister_connection(name=str(connection_id))

    def __init__(self, nodes=None, keyspace='results_database', auth_provider=None, create_keyspace=False, replication_map=None):
        """Validate (and optionally create) the keyspace over a temporary connection.

        Arguments:
            nodes -- list of Cassandra hosts to connect to (default: localhost).
            keyspace -- keyspace every table managed by this context lives in.
            auth_provider -- optional authentication provider for the Cluster.
            create_keyspace -- when True, create the keyspace if it is missing.
            replication_map -- data-center replication map; when None, simple
                replication with a factor of 1 is used instead.

        Raises an Exception if the keyspace is missing and cannot or will not
        be created.
        """
        self.keyspace = keyspace
        self._depth = 0  # Re-entrancy counter maintained by __enter__/__exit__.
        self._connection_id = uuid.uuid4()
        self._models = {}  # Maps raw table names to their registered models.
        self._nodes = nodes if nodes else ['localhost']
        self._auth_provider = auth_provider
        self._batch = []  # Stack of active CountedBatchQuery objects.
        try:
            register_connection(name=str(self._connection_id), session=Cluster(self._nodes, auth_provider=self._auth_provider).connect())
            does_keyspace_exist = self.keyspace in get_cluster(str(self._connection_id)).metadata.keyspaces
            if create_keyspace and not does_keyspace_exist:
                if not self.can_modify_schema():
                    raise Exception('Cannot create keyspace, Schema modification is disabled')
                if replication_map is None:
                    create_keyspace_simple(self.keyspace, replication_factor=1, connections=[str(self._connection_id)])
                else:
                    create_keyspace_network_topology(self.keyspace, dc_replication_map=replication_map, connections=[str(self._connection_id)])
            elif not does_keyspace_exist:
                raise Exception(f'Keyspace {self.keyspace} does not exist and will not be created')
        finally:
            unregister_connection(name=str(self._connection_id))

    def __enter__(self):
        # Only the outermost __enter__ actually opens a session; nested
        # re-entry just increments the depth counter.
        if self._depth == 0:
            register_connection(name=str(self._connection_id), session=Cluster(self._nodes, auth_provider=self._auth_provider).connect(keyspace=self.keyspace))
        self._depth += 1
        # Fix: return self so 'with CassandraContext(...) as context:' binds
        # the context instead of None. Callers using a bare 'with' are unaffected.
        return self

    def __exit__(self, *args, **kwargs):
        self._depth -= 1
        if self._depth <= 0:
            unregister_connection(name=str(self._connection_id))

    def assert_connected(self):
        """Raise AssertionError unless a Cassandra session is currently open."""
        if self._depth <= 0:
            raise AssertionError('No Cassandra session available')

    class AssertConnectedDecorator():
        """Method decorator which asserts an open session before dispatch."""

        def __call__(self, function):
            def decorator(obj, *args, **kwargs):
                obj.assert_connected()
                return function(obj, *args, **kwargs)
            return decorator

    @property
    @AssertConnectedDecorator()
    def cluster(self):
        """The cqlengine Cluster object backing this context's connection."""
        return get_cluster(str(self._connection_id))

    def schema_for_table(self, table_name):
        """Return the table's schema metadata, or None if unavailable."""
        if not self.cluster.metadata:
            return None
        keyspace_metadata = self.cluster.metadata.keyspaces.get(self.keyspace, None)
        if not keyspace_metadata:
            return None
        return keyspace_metadata.tables.get(table_name)

    @AssertConnectedDecorator()
    def create_table(self, model):
        """Register a model with this context, creating its table if needed.

        Raises SchemaException when the table already exists with a schema
        that does not match the provided model.
        """
        does_schema_match = self.does_table_model_match_schema(model)
        if does_schema_match is False:
            raise self.SchemaException('Existing schema does not match provided model')
        table_name = model._raw_column_family_name()
        self._models[table_name] = model
        self._models[table_name].__connection__ = str(self._connection_id)
        self._models[table_name].__keyspace__ = self.keyspace
        if does_schema_match:
            # Table already exists and matches; nothing to create.
            return
        assert self.can_modify_schema()

        # We have a special model named RegexCluster which allows LIKE operations to be performed on a primary key quickly.
        # This was specifically intended for git commits, although has a few other potential uses.
        sasi_index_column = None
        for attr in dir(model):
            if isinstance(getattr(getattr(model, attr, None), 'column', None), RegexCluster):
                if sasi_index_column:
                    raise self.SchemaException('Only one RegexCluster allowed')
                sasi_index_column = attr

        sync_table(model, keyspaces=[self.keyspace], connections=[str(self._connection_id)])

        if sasi_index_column:
            # Issue the raw index-creation statement over a session already
            # bound to our keyspace; one session is sufficient.
            for session in self.cluster.sessions:
                if session.keyspace != self.keyspace:
                    continue
                # https://docs.datastax.com/en/dse/5.1/cql/cql/cql_using/useSASIIndex.html
                session.execute(f"""CREATE CUSTOM INDEX index_{table_name}_{sasi_index_column} ON {table_name} ({sasi_index_column}) USING \
'org.apache.cassandra.index.sasi.SASIIndex' WITH OPTIONS = {{ \
'analyzer_class': 'org.apache.cassandra.index.sasi.analyzer.StandardAnalyzer', \
'case_sensitive': 'true'}}""")
                break

    @AssertConnectedDecorator()
    def does_table_model_match_schema(self, model):
        """Compare a model against the live table schema.

        Returns None when the table does not exist yet, True when the schema
        matches the model and False when it conflicts. Raises SchemaException
        for models which cannot describe a table at all.
        """
        if not issubclass(model, Model):
            raise self.SchemaException('Models must be derived from base Model.')
        if model.__abstract__:
            raise self.SchemaException('Cannot create table from abstract model')

        schema = self.schema_for_table(model._raw_column_family_name())
        if schema is None:
            return None

        # Partition/primary-key columns are order sensitive, data columns are not.
        primary_columns = []
        data_columns = []
        for key in model._columns.keys():
            if getattr(model, key).column.partition_key or getattr(model, key).column.primary_key:
                primary_columns.append(key)
            else:
                data_columns.append(key)

        schema_columns = []
        for column in schema.columns:
            schema_columns.append(column)

        if len(primary_columns) + len(data_columns) != len(schema_columns):
            return False
        for i in range(len(primary_columns)):
            if primary_columns[i] != schema_columns[i]:
                return False
        for element in data_columns:
            if element not in schema_columns:
                return False

        # Per-column checks: CQL type, key membership and clustering order
        # must all agree between the model and the live schema.
        partition_keys = [column.name for column in schema.partition_key]
        primary_keys = [column.name for column in schema.primary_key]
        for column in primary_columns + data_columns:
            model_column = getattr(model, column).column
            if schema.columns[column].cql_type != model_column.db_type:
                return False
            if model_column.partition_key and column not in partition_keys:
                return False
            if model_column.primary_key and column not in primary_keys:
                return False
            if model_column.clustering_order == 'DESC' and not schema.columns[column].is_reversed:
                return False
            if (model_column.clustering_order == 'ASC' or model_column.clustering_order is None) and schema.columns[column].is_reversed:
                return False
        return True

    @AssertConnectedDecorator()
    @contextlib.contextmanager
    def batch_query_context(self, limit=CountedBatchQuery.DEFAULT_LIMIT):
        """Context manager which batches insert_row calls made inside it."""
        self._batch.append(CountedBatchQuery(limit=limit, connection=str(self._connection_id)))
        try:
            with self._batch[-1]:
                yield
        finally:
            del self._batch[-1]

    @AssertConnectedDecorator()
    def insert_row(self, table_name, ttl=None, **kwargs):
        """Insert a row into a previously registered table, honoring any
        active batch context and the optional time-to-live (in seconds)."""
        if table_name not in self._models:
            raise self.SchemaException(f'{table_name} does not exist in the database')

        # If the ttl has already expired, don't even bother sending the data.
        if ttl and ttl < 0:
            return

        if len(self._batch):
            self._models[table_name].batch(self._batch[-1]).ttl(ttl).create(**kwargs)
        else:
            self._models[table_name].ttl(ttl).create(**kwargs)

    @staticmethod
    def filter_for_argument(key, value):
        """Convert a cqlengine-style keyword argument into a filter callable.

        'key' may embed an operator after a double underscore (for example
        'uuid__lt'); with no operator, equality is assumed. The returned
        function takes a row object and reports whether it matches.
        """
        key_value = key.split('__')[0]
        operator = None if len(key.split('__')) == 1 else key.split('__')[1]
        # key_value/value are bound as lambda defaults so each filter captures
        # the current values instead of late-binding enclosing names.
        if operator == 'in':
            return lambda v, key_value=key_value, value=value: getattr(v, key_value) in value
        elif operator == 'gt':
            return lambda v, key_value=key_value, value=value: getattr(v, key_value) > value
        elif operator == 'gte':
            return lambda v, key_value=key_value, value=value: getattr(v, key_value) >= value
        elif operator == 'lt':
            return lambda v, key_value=key_value, value=value: getattr(v, key_value) < value
        elif operator == 'lte':
            return lambda v, key_value=key_value, value=value: getattr(v, key_value) <= value
        elif operator == 'like':
            def regex_filter(v, key_value=key_value, value=value):
                # Treat the value as a literal except for '%', which acts as
                # a SQL-style wildcard.
                regex = re.escape(value)
                regex = regex.replace(re.escape('%'), '.*')
                return bool(re.match(r'\A' + regex + r'\Z', getattr(v, key_value)))
            return regex_filter
        elif operator is None:
            # Fix: this branch previously also tested "operator is 'in'",
            # which was unreachable ('in' is handled above) and relied on
            # string identity rather than equality.
            return lambda v, key_value=key_value, value=value: getattr(v, key_value) == value
        # Fix: this is a staticmethod, so 'self' was undefined here and an
        # unrecognized operator raised NameError instead of SelectException.
        raise CassandraContext.SelectException('Unrecognized operator {}'.format(operator))

    @AssertConnectedDecorator()
    def select_from_table(self, table_name, limit=10000, **kwargs):
        """Query a registered table, filtering client-side where Cassandra cannot.

        Keyword arguments which address a contiguous prefix of the primary key
        are passed to Cassandra directly; everything else is applied as a
        client-side filter over the returned rows. Returns a list of rows.
        """
        if table_name not in self._models:
            raise self.SchemaException(f'{table_name} does not exist in the database')

        create_args = {}
        for name, column in self._models[table_name]._columns.items():
            if not column.partition_key and not column.primary_key:
                continue
            did_find_column_name = False
            using_range_query = False
            for arg, value in kwargs.items():
                # cqlengine will use arguments like '<column_name>__like' to indicate that an argument will have some kind of
                # operation performed. We're looking for column name here.
                if name == arg.split('__')[0] and value is not None:
                    create_args[arg] = value
                    did_find_column_name = True
                    using_range_query = '__' in arg
            # Cassandra can only constrain a contiguous prefix of the primary
            # key, and nothing after a range constraint.
            if not did_find_column_name or using_range_query:
                break

        # Not all versions of Cassandra support filtering. Since filtering is inefficient in Cassandra anyways, queries which rely
        # on filtering should be able to retrieve the data from the database and filter it server-side.
        filters = []
        for arg, value in kwargs.items():
            if arg not in create_args and value is not None:
                filters.append(self.filter_for_argument(arg, value))

        query_to_be_run = self._models[table_name].objects(**create_args).limit(limit)

        # Forces cqlengine to dispatch the query before exiting the function
        result = []
        for element in query_to_be_run:
            for f in filters:
                if not f(element):
                    break
            else:
                result.append(element)
        return result

    class SchemaException(RuntimeError):
        pass

    class SelectException(RuntimeError):
        pass