Add type hints plus minor code fixes

This commit is contained in:
daniele 2024-09-28 09:47:33 +02:00
parent 9390cd2de8
commit 45a07205a1
Signed by: fuxino
GPG Key ID: 981A2B2A3BBF5514
3 changed files with 39 additions and 30 deletions

View File

@ -1,3 +1,3 @@
"""Init.""" """Init."""
__version__ = '4.1.2' __version__ = '4.1.5'

View File

@ -2,7 +2,7 @@
[backup] [backup]
# Files and directories to backup. Multiple items can be separated using a comma (','). It is possible to use wildcards (i.e. '*' to match multiple characters and '~' for the user's home directory). # Files and directories to backup. Multiple items can be separated using a comma (','). It is possible to use wildcards (i.e. '*' to match multiple characters and '~' for the user's home directory).
inputs=/home/my_home,/etc inputs=/home/user
# Output directory. # Output directory.
backup_dir=/media/Backup backup_dir=/media/Backup

View File

@ -14,6 +14,7 @@ Classes:
# Import libraries # Import libraries
import sys import sys
import os import os
from typing import Callable, List, Optional, ParamSpec, TypeVar, Union
import warnings import warnings
from functools import wraps from functools import wraps
from shutil import rmtree, which from shutil import rmtree, which
@ -67,29 +68,29 @@ if journal:
j_handler.setFormatter(j_format) j_handler.setFormatter(j_format)
logger.addHandler(j_handler) logger.addHandler(j_handler)
P = ParamSpec('P')
R = TypeVar('R')
def timing(_logger):
def timing(func: Callable[P, R]) -> Callable[P, R]:
"""Decorator to measure execution time of a function """Decorator to measure execution time of a function
Parameters: Parameters:
_logger: Logger object func: Function to decorate
""" """
def decorator_timing(func): @wraps(func)
@wraps(func) def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
def wrapper_timing(*args, **kwargs): start = default_timer()
start = default_timer()
value = func(*args, **kwargs) value = func(*args, **kwargs)
end = default_timer() end = default_timer()
_logger.info(f'Elapsed time: {end - start:.3f} seconds') logger.info('Elapsed time: %.3f seconds', end - start)
return value return value
return wrapper_timing return wrapper
return decorator_timing
class MyFormatter(argparse.RawTextHelpFormatter, argparse.ArgumentDefaultsHelpFormatter): class MyFormatter(argparse.RawTextHelpFormatter, argparse.ArgumentDefaultsHelpFormatter):
@ -134,8 +135,9 @@ class Backup:
Perform the backup Perform the backup
""" """
def __init__(self, inputs, output, exclude, keep, options, ssh_host=None, ssh_user=None, def __init__(self, inputs: List[str], output: str, exclude: List[str], keep: int, options: str,
ssh_keyfile=None, remote_sudo=False, remove_before=False, verbose=False): ssh_host: Optional[str] = None, ssh_user: Optional[str] = None, ssh_keyfile: Optional[str] = None,
remote_sudo: bool = False, remove_before: bool = False, verbose: bool = False) -> None:
self.inputs = inputs self.inputs = inputs
self.output = output self.output = output
self.exclude = exclude self.exclude = exclude
@ -152,12 +154,12 @@ class Backup:
self._output_dir = '' self._output_dir = ''
self._inputs_path = '' self._inputs_path = ''
self._exclude_path = '' self._exclude_path = ''
self._remote = None self._remote = False
self._ssh = None self._ssh = None
self._password_auth = False self._password_auth = False
self._password = None self._password = None
def check_params(self, homedir=''): def check_params(self, homedir: str = '') -> int:
"""Check if parameters for the backup are valid""" """Check if parameters for the backup are valid"""
if self.inputs is None or len(self.inputs) == 0: if self.inputs is None or len(self.inputs) == 0:
@ -201,7 +203,7 @@ class Backup:
return 0 return 0
# Function to create the actual backup directory # Function to create the actual backup directory
def define_backup_dir(self): def define_backup_dir(self) -> None:
"""Define the actual backup dir""" """Define the actual backup dir"""
now = datetime.now().strftime('%Y-%m-%d %H:%M:%S') now = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
self._output_dir = f'{self.output}/simple_backup/{now}' self._output_dir = f'{self.output}/simple_backup/{now}'
@ -209,10 +211,12 @@ class Backup:
if self._remote: if self._remote:
self._server = f'{self.ssh_user}@{self.ssh_host}:' self._server = f'{self.ssh_user}@{self.ssh_host}:'
def remove_old_backups(self): def remove_old_backups(self) -> None:
"""Remove old backups if there are more than indicated by 'keep'""" """Remove old backups if there are more than indicated by 'keep'"""
if self._remote: if self._remote:
assert self._ssh is not None
_, stdout, _ = self._ssh.exec_command(f'ls {self.output}/simple_backup') _, stdout, _ = self._ssh.exec_command(f'ls {self.output}/simple_backup')
dirs = stdout.read().decode('utf-8').strip().split('\n') dirs = stdout.read().decode('utf-8').strip().split('\n')
@ -272,7 +276,7 @@ class Backup:
elif count > 1: elif count > 1:
logger.info('Removed %d backups', count) logger.info('Removed %d backups', count)
def find_last_backup(self): def find_last_backup(self) -> None:
"""Get path of last backup (from last_backup symlink) for rsync --link-dest""" """Get path of last backup (from last_backup symlink) for rsync --link-dest"""
if self._remote: if self._remote:
@ -298,7 +302,7 @@ class Backup:
logger.critical('Cannot access the backup directory. Permission denied') logger.critical('Cannot access the backup directory. Permission denied')
try: try:
notify('Backup failed (check log for details)') _notify('Backup failed (check log for details)')
except NameError: except NameError:
pass pass
@ -309,17 +313,18 @@ class Backup:
except IndexError: except IndexError:
logger.info('No previous backups available') logger.info('No previous backups available')
def _ssh_connect(self, homedir=''): def _ssh_connect(self, homedir: str = '') -> paramiko.client.SSHClient:
try: try:
ssh = paramiko.SSHClient() ssh = paramiko.SSHClient()
except NameError: except NameError:
logger.error('Install paramiko for ssh support') logger.error('Install paramiko for ssh support')
return None return None
try: try:
ssh.load_host_keys(filename=f'{homedir}/.ssh/known_hosts') ssh.load_host_keys(filename=f'{homedir}/.ssh/known_hosts')
except FileNotFoundError: except FileNotFoundError:
logger.warning(f'Cannot find file {homedir}/.ssh/known_hosts') logger.warning('Cannot find file %s/.ssh/known_hosts', homedir)
ssh.set_missing_host_key_policy(paramiko.WarningPolicy()) ssh.set_missing_host_key_policy(paramiko.WarningPolicy())
@ -417,7 +422,7 @@ class Backup:
return ssh return ssh
def _returncode_log(self, returncode): def _returncode_log(self, returncode: int) -> None:
match returncode: match returncode:
case 2: case 2:
logger.error('Rsync error (return code 2) - Protocol incompatibility') logger.error('Rsync error (return code 2) - Protocol incompatibility')
@ -447,8 +452,8 @@ class Backup:
logger.error('Rsync error (return code %d) - Check rsync(1) for details', returncode) logger.error('Rsync error (return code %d) - Check rsync(1) for details', returncode)
# Function to read configuration file # Function to read configuration file
@timing(logger) @timing
def run(self): def run(self) -> int:
"""Perform the backup""" """Perform the backup"""
logger.info('Starting backup...') logger.info('Starting backup...')
@ -477,7 +482,7 @@ class Backup:
logger.info('No existing files or directories specified for backup. Nothing to do') logger.info('No existing files or directories specified for backup. Nothing to do')
try: try:
notify('Backup finished. No files copied') _notify('Backup finished. No files copied')
except NameError: except NameError:
pass pass
@ -518,6 +523,7 @@ class Backup:
args = shlex.split(rsync) args = shlex.split(rsync)
with Popen(args, stdin=PIPE, stdout=PIPE, stderr=STDOUT, shell=False) as p: with Popen(args, stdin=PIPE, stdout=PIPE, stderr=STDOUT, shell=False) as p:
output: Union[bytes, List[str]]
output, _ = p.communicate() output, _ = p.communicate()
try: try:
@ -551,6 +557,8 @@ class Backup:
os.remove(self._exclude_path) os.remove(self._exclude_path)
if self._remote: if self._remote:
assert self._ssh is not None
_, stdout, _ = self._ssh.exec_command(f'if [ -d "{self._output_dir}" ]; then echo "ok"; fi') _, stdout, _ = self._ssh.exec_command(f'if [ -d "{self._output_dir}" ]; then echo "ok"; fi')
output = stdout.read().decode('utf-8').strip() output = stdout.read().decode('utf-8').strip()
@ -600,7 +608,7 @@ def _parse_arguments():
user = os.getenv('SUDO_USER') user = os.getenv('SUDO_USER')
else: else:
user = os.getenv('USER') user = os.getenv('USER')
homedir = os.path.expanduser(f'~{user}') homedir = os.path.expanduser(f'~{user}')
parser = argparse.ArgumentParser(prog='simple_backup', parser = argparse.ArgumentParser(prog='simple_backup',
@ -803,6 +811,7 @@ def simple_backup():
config_args = _read_config(args.config, user) config_args = _read_config(args.config, user)
except (configparser.NoSectionError, configparser.NoOptionError): except (configparser.NoSectionError, configparser.NoOptionError):
logger.critical('Bad configuration file') logger.critical('Bad configuration file')
return 6 return 6
inputs = args.inputs if args.inputs is not None else config_args['inputs'] inputs = args.inputs if args.inputs is not None else config_args['inputs']