#!/usr/bin/env python3
"""
Tests for the SSL Manager module.

This module contains tests for the SSLManager class and its methods.
"""

import os
import sys
import tempfile
import unittest
from datetime import datetime, timedelta, timezone
from unittest.mock import MagicMock, patch

# Add the src directory to the Python path so ``ssl_manager`` can be imported
# directly from a source checkout.
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '../src')))

from ssl_manager import SSLManager  # noqa: E402  (depends on the sys.path tweak above)


def _format_cert_time(when):
    """Return *when* in the ``'Mon DD HH:MM:SS YYYY GMT'`` layout of ``notAfter``.

    Used to build certificate fixtures relative to the current time instead of
    hard-coding calendar dates that eventually fall into the past.
    """
    return when.strftime('%b %d %H:%M:%S %Y GMT')


class TestSSLManager(unittest.TestCase):
    """Test cases for the SSLManager class."""

    def setUp(self):
        """Create an SSLManager whose cert_dir is a fresh temporary directory."""
        self.temp_dir = tempfile.TemporaryDirectory()
        # Registered before SSLManager is constructed so the directory is removed
        # even if the constructor raises (tearDown would not run in that case).
        self.addCleanup(self.temp_dir.cleanup)
        self.ssl_manager = SSLManager(cert_dir=self.temp_dir.name)

    def _write_dummy_cert(self, filename, label):
        """Write a placeholder PEM-style file into the temp dir and return its path."""
        cert_path = os.path.join(self.temp_dir.name, filename)
        with open(cert_path, 'w', encoding='utf-8') as f:
            f.write("\n".join((
                "-----BEGIN CERTIFICATE-----",
                label,
                "-----END CERTIFICATE-----",
            )))
        return cert_path

    @patch('ssl_manager.socket.socket')
    @patch('ssl_manager.ssl.create_default_context')
    def test_check_cert_expiration_valid(self, mock_context, mock_socket):
        """A certificate expiring in the future is reported as 'Valid'."""
        # Mock the SSL socket returned by context.wrap_socket(...).
        mock_sock = MagicMock()
        mock_context.return_value.wrap_socket.return_value = mock_sock
        # Make ``with ... as s`` yield the same mock, so the fixture applies
        # whether or not the implementation uses the socket as a context manager.
        mock_sock.__enter__.return_value = mock_sock

        # Expiry is relative to "now": a hard-coded date would make this test
        # start failing once that date has passed.
        not_after = _format_cert_time(datetime.now(timezone.utc) + timedelta(days=365))
        mock_sock.getpeercert.return_value = {
            'notAfter': not_after,
            'issuer': [(('organizationName', 'Test CA'),)],
            'subject': [(('commonName', 'example.com'),)],
        }

        result = self.ssl_manager.check_cert_expiration('example.com')

        self.assertEqual(result['hostname'], 'example.com')
        self.assertEqual(result['status'], 'Valid')
        self.assertIn('days_left', result)
        self.assertIn('expiration_date', result)

    @patch('ssl_manager.socket.socket')
    @patch('ssl_manager.ssl.create_default_context')
    def test_check_cert_expiration_error(self, mock_context, mock_socket):
        """A connection failure is reported with status 'Error' and its message."""
        # Mock the SSL socket so that connecting raises.
        mock_sock = MagicMock()
        mock_context.return_value.wrap_socket.return_value = mock_sock
        # Same context-manager routing as in the 'valid' test.
        mock_sock.__enter__.return_value = mock_sock
        mock_sock.connect.side_effect = Exception("Connection refused")

        result = self.ssl_manager.check_cert_expiration('nonexistent.example.com')

        self.assertEqual(result['hostname'], 'nonexistent.example.com')
        self.assertEqual(result['status'], 'Error')
        self.assertIn('error', result)
        self.assertEqual(result['error'], 'Connection refused')

    @patch('ssl_manager.subprocess.run')
    def test_generate_self_signed_cert(self, mock_run):
        """Key and cert paths are placed in cert_dir and built via two subprocess calls."""
        mock_run.return_value = MagicMock(returncode=0)

        cert_path, key_path = self.ssl_manager.generate_self_signed_cert('test.example.com')

        self.assertEqual(cert_path, os.path.join(self.temp_dir.name, 'test.example.com.crt'))
        self.assertEqual(key_path, os.path.join(self.temp_dir.name, 'test.example.com.key'))
        # Subprocess is expected to run twice (once for the key, once for the cert).
        self.assertEqual(mock_run.call_count, 2)

    @patch('ssl_manager.subprocess.run')
    def test_validate_cert_chain_valid(self, mock_run):
        """A zero exit status from the verification subprocess yields True."""
        cert_path = self._write_dummy_cert('test.crt', 'Dummy Certificate')
        mock_run.return_value = MagicMock(
            returncode=0,
            stdout="test.crt: OK",
        )

        result = self.ssl_manager.validate_cert_chain(cert_path)

        self.assertTrue(result)
        mock_run.assert_called_once()

    @patch('ssl_manager.subprocess.run')
    def test_validate_cert_chain_invalid(self, mock_run):
        """A non-zero exit status from the verification subprocess yields False."""
        cert_path = self._write_dummy_cert('invalid.crt', 'Invalid Certificate')
        mock_run.return_value = MagicMock(
            returncode=1,
            stdout=(
                "invalid.crt: C = US, O = Example, CN = example.com\n"
                "error 2 at 1 depth lookup: unable to get issuer certificate"
            ),
        )

        result = self.ssl_manager.validate_cert_chain(cert_path)

        self.assertFalse(result)
        mock_run.assert_called_once()


if __name__ == '__main__':
    unittest.main()