319 lines
12 KiB
Python
319 lines
12 KiB
Python
# Copyright 2019-2020 Canonical Ltd.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
from datetime import timedelta
|
|
import pickle
|
|
import shutil
|
|
import subprocess
|
|
import sqlite3
|
|
import typing
|
|
|
|
import yaml
|
|
|
|
|
|
def _run(args, **kw):
|
|
cmd = shutil.which(args[0])
|
|
if cmd is None:
|
|
raise FileNotFoundError(args[0])
|
|
return subprocess.run([cmd, *args[1:]], **kw)
|
|
|
|
|
|
class SQLiteStorage:
|
|
|
|
DB_LOCK_TIMEOUT = timedelta(hours=1)
|
|
|
|
def __init__(self, filename):
|
|
# The isolation_level argument is set to None such that the implicit
|
|
# transaction management behavior of the sqlite3 module is disabled.
|
|
self._db = sqlite3.connect(str(filename),
|
|
isolation_level=None,
|
|
timeout=self.DB_LOCK_TIMEOUT.total_seconds())
|
|
self._setup()
|
|
|
|
def _setup(self):
|
|
# Make sure that the database is locked until the connection is closed,
|
|
# not until the transaction ends.
|
|
self._db.execute("PRAGMA locking_mode=EXCLUSIVE")
|
|
c = self._db.execute("BEGIN")
|
|
c.execute("SELECT count(name) FROM sqlite_master WHERE type='table' AND name='snapshot'")
|
|
if c.fetchone()[0] == 0:
|
|
# Keep in mind what might happen if the process dies somewhere below.
|
|
# The system must not be rendered permanently broken by that.
|
|
self._db.execute("CREATE TABLE snapshot (handle TEXT PRIMARY KEY, data BLOB)")
|
|
self._db.execute('''
|
|
CREATE TABLE notice (
|
|
sequence INTEGER PRIMARY KEY AUTOINCREMENT,
|
|
event_path TEXT,
|
|
observer_path TEXT,
|
|
method_name TEXT)
|
|
''')
|
|
self._db.commit()
|
|
|
|
def close(self):
|
|
self._db.close()
|
|
|
|
def commit(self):
|
|
self._db.commit()
|
|
|
|
# There's commit but no rollback. For abort to be supported, we'll need logic that
|
|
# can rollback decisions made by third-party code in terms of the internal state
|
|
# of objects that have been snapshotted, and hooks to let them know about it and
|
|
# take the needed actions to undo their logic until the last snapshot.
|
|
# This is doable but will increase significantly the chances for mistakes.
|
|
|
|
def save_snapshot(self, handle_path: str, snapshot_data: typing.Any) -> None:
|
|
"""Part of the Storage API, persist a snapshot data under the given handle.
|
|
|
|
Args:
|
|
handle_path: The string identifying the snapshot.
|
|
snapshot_data: The data to be persisted. (as returned by Object.snapshot()). This
|
|
might be a dict/tuple/int, but must only contain 'simple' python types.
|
|
"""
|
|
# Use pickle for serialization, so the value remains portable.
|
|
raw_data = pickle.dumps(snapshot_data)
|
|
self._db.execute("REPLACE INTO snapshot VALUES (?, ?)", (handle_path, raw_data))
|
|
|
|
def load_snapshot(self, handle_path: str) -> typing.Any:
|
|
"""Part of the Storage API, retrieve a snapshot that was previously saved.
|
|
|
|
Args:
|
|
handle_path: The string identifying the snapshot.
|
|
Raises:
|
|
NoSnapshotError: if there is no snapshot for the given handle_path.
|
|
"""
|
|
c = self._db.cursor()
|
|
c.execute("SELECT data FROM snapshot WHERE handle=?", (handle_path,))
|
|
row = c.fetchone()
|
|
if row:
|
|
return pickle.loads(row[0])
|
|
raise NoSnapshotError(handle_path)
|
|
|
|
def drop_snapshot(self, handle_path: str):
|
|
"""Part of the Storage API, remove a snapshot that was previously saved.
|
|
|
|
Dropping a snapshot that doesn't exist is treated as a no-op.
|
|
"""
|
|
self._db.execute("DELETE FROM snapshot WHERE handle=?", (handle_path,))
|
|
|
|
def list_snapshots(self) -> typing.Generator[str, None, None]:
|
|
"""Return the name of all snapshots that are currently saved."""
|
|
c = self._db.cursor()
|
|
c.execute("SELECT handle FROM snapshot")
|
|
while True:
|
|
rows = c.fetchmany()
|
|
if not rows:
|
|
break
|
|
for row in rows:
|
|
yield row[0]
|
|
|
|
def save_notice(self, event_path: str, observer_path: str, method_name: str) -> None:
|
|
"""Part of the Storage API, record an notice (event and observer)"""
|
|
self._db.execute('INSERT INTO notice VALUES (NULL, ?, ?, ?)',
|
|
(event_path, observer_path, method_name))
|
|
|
|
def drop_notice(self, event_path: str, observer_path: str, method_name: str) -> None:
|
|
"""Part of the Storage API, remove a notice that was previously recorded."""
|
|
self._db.execute('''
|
|
DELETE FROM notice
|
|
WHERE event_path=?
|
|
AND observer_path=?
|
|
AND method_name=?
|
|
''', (event_path, observer_path, method_name))
|
|
|
|
def notices(self, event_path: typing.Optional[str]) ->\
|
|
typing.Generator[typing.Tuple[str, str, str], None, None]:
|
|
"""Part of the Storage API, return all notices that begin with event_path.
|
|
|
|
Args:
|
|
event_path: If supplied, will only yield events that match event_path. If not
|
|
supplied (or None/'') will return all events.
|
|
Returns:
|
|
Iterable of (event_path, observer_path, method_name) tuples
|
|
"""
|
|
if event_path:
|
|
c = self._db.execute('''
|
|
SELECT event_path, observer_path, method_name
|
|
FROM notice
|
|
WHERE event_path=?
|
|
ORDER BY sequence
|
|
''', (event_path,))
|
|
else:
|
|
c = self._db.execute('''
|
|
SELECT event_path, observer_path, method_name
|
|
FROM notice
|
|
ORDER BY sequence
|
|
''')
|
|
while True:
|
|
rows = c.fetchmany()
|
|
if not rows:
|
|
break
|
|
for row in rows:
|
|
yield tuple(row)
|
|
|
|
|
|
class JujuStorage:
|
|
""""Storing the content tracked by the Framework in Juju.
|
|
|
|
This uses :class:`_JujuStorageBackend` to interact with state-get/state-set
|
|
as the way to store state for the framework and for components.
|
|
"""
|
|
|
|
NOTICE_KEY = "#notices#"
|
|
|
|
def __init__(self, backend: '_JujuStorageBackend' = None):
|
|
self._backend = backend
|
|
if backend is None:
|
|
self._backend = _JujuStorageBackend()
|
|
|
|
def close(self):
|
|
return
|
|
|
|
def commit(self):
|
|
return
|
|
|
|
def save_snapshot(self, handle_path: str, snapshot_data: typing.Any) -> None:
|
|
self._backend.set(handle_path, snapshot_data)
|
|
|
|
def load_snapshot(self, handle_path):
|
|
try:
|
|
content = self._backend.get(handle_path)
|
|
except KeyError:
|
|
raise NoSnapshotError(handle_path)
|
|
return content
|
|
|
|
def drop_snapshot(self, handle_path):
|
|
self._backend.delete(handle_path)
|
|
|
|
def save_notice(self, event_path: str, observer_path: str, method_name: str):
|
|
notice_list = self._load_notice_list()
|
|
notice_list.append([event_path, observer_path, method_name])
|
|
self._save_notice_list(notice_list)
|
|
|
|
def drop_notice(self, event_path: str, observer_path: str, method_name: str):
|
|
notice_list = self._load_notice_list()
|
|
notice_list.remove([event_path, observer_path, method_name])
|
|
self._save_notice_list(notice_list)
|
|
|
|
def notices(self, event_path: str):
|
|
notice_list = self._load_notice_list()
|
|
for row in notice_list:
|
|
if row[0] != event_path:
|
|
continue
|
|
yield tuple(row)
|
|
|
|
def _load_notice_list(self) -> typing.List[typing.Tuple[str]]:
|
|
try:
|
|
notice_list = self._backend.get(self.NOTICE_KEY)
|
|
except KeyError:
|
|
return []
|
|
if notice_list is None:
|
|
return []
|
|
return notice_list
|
|
|
|
def _save_notice_list(self, notices: typing.List[typing.Tuple[str]]) -> None:
|
|
self._backend.set(self.NOTICE_KEY, notices)
|
|
|
|
|
|
class _SimpleLoader(getattr(yaml, 'CSafeLoader', yaml.SafeLoader)):
|
|
"""Handle a couple basic python types.
|
|
|
|
yaml.SafeLoader can handle all the basic int/float/dict/set/etc that we want. The only one
|
|
that it *doesn't* handle is tuples. We don't want to support arbitrary types, so we just
|
|
subclass SafeLoader and add tuples back in.
|
|
"""
|
|
# Taken from the example at:
|
|
# https://stackoverflow.com/questions/9169025/how-can-i-add-a-python-tuple-to-a-yaml-file-using-pyyaml
|
|
|
|
construct_python_tuple = yaml.Loader.construct_python_tuple
|
|
|
|
|
|
_SimpleLoader.add_constructor(
|
|
u'tag:yaml.org,2002:python/tuple',
|
|
_SimpleLoader.construct_python_tuple)
|
|
|
|
|
|
class _SimpleDumper(getattr(yaml, 'CSafeDumper', yaml.SafeDumper)):
|
|
"""Add types supported by 'marshal'
|
|
|
|
YAML can support arbitrary types, but that is generally considered unsafe (like pickle). So
|
|
we want to only support dumping out types that are safe to load.
|
|
"""
|
|
|
|
|
|
_SimpleDumper.represent_tuple = yaml.Dumper.represent_tuple
|
|
_SimpleDumper.add_representer(tuple, _SimpleDumper.represent_tuple)
|
|
|
|
|
|
def juju_backend_available() -> bool:
|
|
"""Check if Juju state storage is available."""
|
|
p = shutil.which('state-get')
|
|
return p is not None
|
|
|
|
|
|
class _JujuStorageBackend:
|
|
"""Implements the interface from the Operator framework to Juju's state-get/set/etc."""
|
|
|
|
def set(self, key: str, value: typing.Any) -> None:
|
|
"""Set a key to a given value.
|
|
|
|
Args:
|
|
key: The string key that will be used to find the value later
|
|
value: Arbitrary content that will be returned by get().
|
|
Raises:
|
|
CalledProcessError: if 'state-set' returns an error code.
|
|
"""
|
|
# default_flow_style=None means that it can use Block for
|
|
# complex types (types that have nested types) but use flow
|
|
# for simple types (like an array). Not all versions of PyYAML
|
|
# have the same default style.
|
|
encoded_value = yaml.dump(value, Dumper=_SimpleDumper, default_flow_style=None)
|
|
content = yaml.dump(
|
|
{key: encoded_value}, encoding='utf8', default_style='|',
|
|
default_flow_style=False,
|
|
Dumper=_SimpleDumper)
|
|
_run(["state-set", "--file", "-"], input=content, check=True)
|
|
|
|
def get(self, key: str) -> typing.Any:
|
|
"""Get the bytes value associated with a given key.
|
|
|
|
Args:
|
|
key: The string key that will be used to find the value
|
|
Raises:
|
|
CalledProcessError: if 'state-get' returns an error code.
|
|
"""
|
|
# We don't capture stderr here so it can end up in debug logs.
|
|
p = _run(["state-get", key], stdout=subprocess.PIPE, check=True, universal_newlines=True)
|
|
if p.stdout == '' or p.stdout == '\n':
|
|
raise KeyError(key)
|
|
return yaml.load(p.stdout, Loader=_SimpleLoader)
|
|
|
|
def delete(self, key: str) -> None:
|
|
"""Remove a key from being tracked.
|
|
|
|
Args:
|
|
key: The key to stop storing
|
|
Raises:
|
|
CalledProcessError: if 'state-delete' returns an error code.
|
|
"""
|
|
_run(["state-delete", key], check=True)
|
|
|
|
|
|
class NoSnapshotError(Exception):
|
|
|
|
def __init__(self, handle_path):
|
|
self.handle_path = handle_path
|
|
|
|
def __str__(self):
|
|
return 'no snapshot data found for {} object'.format(self.handle_path)
|