#!/usr/bin/env python3
"""
Tests for the DNS validation functionality of the SSL Manager.

This module contains tests for checking if a hostname is in public DNS.
All name resolution is mocked, so no test here touches the real network.
"""

import os
import sys
import socket
import unittest
from unittest.mock import patch, MagicMock

# Make the code under test importable, including when this file is run
# directly (``python path/to/this_test.py``). In that case sys.path[0] is this
# file's own directory, not the project root.
#
# The import below is ``from src.ssl_manager import ...`` (and one patch target
# is ``'src.ssl_manager.is_hostname_in_public_dns'``), so the directory that
# must be on sys.path is the one that *contains* ``src``, which is the project
# root. Inserting only ``../src`` (as before) is not enough for that import.
# The ``src`` directory is still added, and kept first as it was originally,
# in case ssl_manager imports its sibling modules as top-level names
# (NOTE(review): unverified from here).
_TESTS_DIR = os.path.dirname(os.path.abspath(__file__))
_PROJECT_ROOT = os.path.abspath(os.path.join(_TESTS_DIR, '..'))
_SRC_DIR = os.path.join(_PROJECT_ROOT, 'src')
sys.path.insert(0, _PROJECT_ROOT)
sys.path.insert(0, _SRC_DIR)

# Import from ssl_manager module
from src.ssl_manager import is_hostname_in_public_dns, SSLManager


class TestDNSValidation(unittest.TestCase):
    """Test cases for DNS validation functionality.

    The ``is_hostname_in_public_dns`` tests patch ``socket.gethostbyname`` on
    the ``socket`` module itself. The patch only takes effect if the
    implementation calls it as ``socket.gethostbyname`` rather than through a
    ``from socket import gethostbyname`` binding
    (NOTE(review): confirm against src/ssl_manager.py).
    """

    def setUp(self):
        """Create a fresh SSLManager instance for each test."""
        self.ssl_manager = SSLManager()

    @patch('socket.gethostbyname')
    def test_hostname_in_public_dns_positive(self, mock_gethostbyname):
        """Test that a hostname resolving to a public IP returns True."""
        # Resolve to a well-known public address.
        mock_gethostbyname.return_value = '8.8.8.8'  # Google's public DNS server

        result = is_hostname_in_public_dns('example.com')

        self.assertTrue(result)
        mock_gethostbyname.assert_called_once_with('example.com')

    @patch('socket.gethostbyname')
    def test_hostname_in_public_dns_negative_private_ip(self, mock_gethostbyname):
        """Test that a hostname resolving to a private IP returns False."""
        # 192.168.0.0/16 is an RFC 1918 private range.
        mock_gethostbyname.return_value = '192.168.1.1'

        result = is_hostname_in_public_dns('internal.local')

        self.assertFalse(result)
        mock_gethostbyname.assert_called_once_with('internal.local')

    @patch('socket.gethostbyname')
    def test_hostname_in_public_dns_negative_loopback(self, mock_gethostbyname):
        """Test that a hostname resolving to a loopback IP returns False."""
        mock_gethostbyname.return_value = '127.0.0.1'  # Loopback IP address

        result = is_hostname_in_public_dns('localhost')

        self.assertFalse(result)
        mock_gethostbyname.assert_called_once_with('localhost')

    @patch('socket.gethostbyname')
    def test_hostname_in_public_dns_negative_link_local(self, mock_gethostbyname):
        """Test that a hostname resolving to a link-local IP returns False."""
        # 169.254.0.0/16 is the IPv4 link-local range.
        mock_gethostbyname.return_value = '169.254.1.1'

        result = is_hostname_in_public_dns('link-local.local')

        self.assertFalse(result)
        mock_gethostbyname.assert_called_once_with('link-local.local')

    @patch('socket.gethostbyname')
    def test_hostname_in_public_dns_negative_gaierror(self, mock_gethostbyname):
        """Test that a hostname that cannot be resolved returns False."""
        # gaierror is what the resolver raises for unknown names; the function
        # under test must turn it into False rather than propagate it.
        mock_gethostbyname.side_effect = socket.gaierror("Name or service not known")

        result = is_hostname_in_public_dns('nonexistent.example.com')

        self.assertFalse(result)
        mock_gethostbyname.assert_called_once_with('nonexistent.example.com')

    @patch('socket.gethostbyname')
    def test_hostname_in_public_dns_negative_timeout(self, mock_gethostbyname):
        """Test that a hostname that times out during resolution returns False."""
        # A resolver timeout must also be reported as False, not raised.
        mock_gethostbyname.side_effect = socket.timeout("Timed out")

        result = is_hostname_in_public_dns('slow.example.com')

        self.assertFalse(result)
        mock_gethostbyname.assert_called_once_with('slow.example.com')

    # Decorators apply bottom-up, so the mocks arrive in the order
    # (gethostbyname, setdefaulttimeout, getdefaulttimeout).
    @patch('socket.getdefaulttimeout')
    @patch('socket.setdefaulttimeout')
    @patch('socket.gethostbyname')
    def test_hostname_in_public_dns_with_custom_timeout(self, mock_gethostbyname,
                                                        mock_set_timeout,
                                                        mock_get_timeout):
        """Test that the function uses the provided timeout value."""
        mock_gethostbyname.return_value = '8.8.8.8'  # Google's public DNS server
        mock_get_timeout.return_value = None  # Simulate "no default timeout set"

        result = is_hostname_in_public_dns('example.com', timeout=5.0)

        # assert_any_call (not assert_called_once_with) tolerates additional
        # setdefaulttimeout calls, e.g. one that restores the previous value.
        mock_set_timeout.assert_any_call(5.0)
        self.assertTrue(result)
        mock_gethostbyname.assert_called_once_with('example.com')

    # Patch the name where SSLManager looks it up (the src.ssl_manager module
    # namespace), not where it was originally defined.
    @patch('src.ssl_manager.is_hostname_in_public_dns')
    @patch('subprocess.run')
    def test_generate_letsencrypt_cert_checks_dns(self, mock_run,
                                                  mock_is_hostname_in_public_dns):
        """Test that generate_letsencrypt_cert checks if the hostname is in public DNS."""
        # Make the DNS pre-check fail.
        mock_is_hostname_in_public_dns.return_value = False

        with self.assertRaises(ValueError) as context:
            self.ssl_manager.generate_letsencrypt_cert(
                common_name="example.com",
                email="test@example.com"
            )

        self.assertIn("not in public DNS", str(context.exception))

        # The DNS check must receive the manager's configured connection timeout.
        mock_is_hostname_in_public_dns.assert_called_once_with(
            "example.com",
            timeout=self.ssl_manager.connection_timeout
        )

        # Failing DNS validation must short-circuit before any external
        # command is launched via subprocess.run.
        mock_run.assert_not_called()


if __name__ == '__main__':
    unittest.main()