From f95cd86cdccd8f7bf37ff7ea1f9d9f63ef984caf Mon Sep 17 00:00:00 2001
From: Fuxino <dfucini@gmail.com>
Date: Thu, 15 Jun 2023 11:48:23 +0200
Subject: [PATCH] Check backup folder for remote backup

Check that the backup folder exists at the end of the backup when
performing backup over ssh
---
 simple_backup/simple_backup.py | 66 +++++++++++++++++++++++++---------
 1 file changed, 49 insertions(+), 17 deletions(-)

diff --git a/simple_backup/simple_backup.py b/simple_backup/simple_backup.py
index ba2ac91..3ca6605 100755
--- a/simple_backup/simple_backup.py
+++ b/simple_backup/simple_backup.py
@@ -271,9 +271,6 @@ class Backup:
         elif count > 1:
             logger.info('Removed %d backups', count)
 
-        if self._ssh:
-            self._ssh.close()
-
     def find_last_backup(self):
         """Get path of last backup (from last_backup symlink) for rsync --link-dest"""
 
@@ -304,7 +301,12 @@ class Backup:
 
     def _ssh_connect(self):
         ssh = paramiko.SSHClient()
-        ssh.load_system_host_keys()
+
+        try:
+            ssh.load_host_keys(filename=f'{homedir}/.ssh/known_hosts')
+        except FileNotFoundError:
+            logger.warning(f'Cannot find file {homedir}/.ssh/known_hosts')
+
         ssh.set_missing_host_key_policy(paramiko.WarningPolicy())
 
         try:
@@ -318,6 +320,11 @@ class Backup:
                 ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy())
             else:
                 return None
+        except paramiko.BadHostKeyException as e:
+            logger.critical('Can\'t connect to the server.')
+            logger.critical(e)
+
+            return None
         except paramiko.SSHException:
             pass
 
@@ -436,9 +443,11 @@ class Backup:
                     f'{self._exclude_path} --files-from={self._inputs_path} / "{self._server}{self._output_dir}"'
 
         if euid == 0 and self.ssh_keyfile is not None:
-            rsync = f'{rsync} -e \'ssh -i {self.ssh_keyfile}\''
+            rsync = f'{rsync} -e \'ssh -i {self.ssh_keyfile} -o StrictHostKeyChecking=no\''
         elif self._password_auth and which('sshpass'):
-            rsync = f'{rsync} -e \'sshpass -e ssh -l {self.username}\''
+            rsync = f'{rsync} -e \'sshpass -e ssh -l {self.username} -o StrictHostKeyChecking=no\''
+        else:
+            rsync = f'{rsync} -e \'ssh -o StrictHostKeyChecking=no\''
 
         args = shlex.split(rsync)
 
@@ -467,20 +476,43 @@ class Backup:
         os.remove(self._inputs_path)
         os.remove(self._exclude_path)
 
-        logger.info('Backup completed')
+        if self._remote:
+            _, stdout, _ = self._ssh.exec_command(f'if [ -d "{self._output_dir}" ]; then echo "ok"; fi')
 
-        if self._err_flag:
-            logger.warning('Some errors occurred')
+            output = stdout.read().decode('utf-8').strip()
 
-            try:
-                _notify('Backup finished with errors (check log for details)')
-            except NameError:
-                pass
+            if output == 'ok':
+                logger.info('Backup completed')
+
+                try:
+                    _notify('Backup completed')
+                except NameError:
+                    pass
+            else:
+                logger.error('Backup failed')
+
+                try:
+                    _notify('Backup failed (check log for details)')
+                except NameError:
+                    pass
+
+            if self._ssh:
+                self._ssh.close()
         else:
-            try:
-                _notify('Backup finished')
-            except NameError:
-                pass
+            if self._err_flag:
+                logger.error('Some errors occurred while performing the backup')
+
+                try:
+                    _notify('Some errors occurred while performing the backup. Check log for details')
+                except NameError:
+                    pass
+            else:
+                logger.info('Backup completed')
+
+                try:
+                    _notify('Backup completed')
+                except NameError:
+                    pass
 
 
 def _parse_arguments():