diff --git a/.junie/guidelines.md b/.junie/guidelines.md index fe6f0f8..20cefc6 100644 --- a/.junie/guidelines.md +++ b/.junie/guidelines.md @@ -38,6 +38,7 @@ The SSL Manager uses a configuration file (`config.json`) to store default setti "connection_timeout": 3.0, "default_validity_days": 365, "key_size": 2048, + "debug": false, "unifi": { "host": "unifi.example.com", "username": "admin", @@ -47,6 +48,12 @@ The SSL Manager uses a configuration file (`config.json`) to store default setti "ssh_username": "root", "ssh_password": "", "ssh_key_path": "~/.ssh/id_rsa" + }, + "letsencrypt": { + "email": "admin@example.com", + "validation_method": "standalone", + "use_staging": false, + "agree_tos": true } } ``` @@ -57,6 +64,7 @@ Configuration options: - `connection_timeout`: Timeout in seconds for SSL connections - `default_validity_days`: Default validity period in days for generated certificates - `key_size`: Key size in bits for generated certificates +- `debug`: Enable debug logging with line numbers and file names (default: false) - `unifi`: UniFi device connection parameters - `host`: Hostname or IP address of the UniFi device - `username`: Username for authenticating with the UniFi device @@ -66,6 +74,11 @@ Configuration options: - `ssh_username`: Username for SSH authentication with the UniFi device - `ssh_password`: Password for SSH authentication (leave empty to use SSH key) - `ssh_key_path`: Path to the SSH private key file for authentication +- `letsencrypt`: Let's Encrypt certificate settings + - `email`: Email address for Let's Encrypt registration and important notifications + - `validation_method`: Method to use for domain validation (standalone, webroot, dns) + - `use_staging`: Whether to use Let's Encrypt's staging environment for testing (true/false) + - `agree_tos`: Whether to automatically agree to the Terms of Service (true/false) ### Usage @@ -73,6 +86,7 @@ The SSL Manager provides three main commands. All commands support the following - `--config`: Path to the config file (default: config.json) - `--cert-dir`: Directory to store certificates (overrides config) +- `--debug`: Enable debug logging with line numbers and file names 1. **Check Certificate Expiration**: ```bash @@ -80,17 +94,107 @@ The SSL Manager provides three main commands. All commands support the following ``` The `--port` option overrides the `default_port` from the config file. -2. **Generate Self-Signed Certificate**: +2. **Generate Certificate**: ```bash - python src/ssl_manager.py generate example.com [--days DAYS] + python src/ssl_manager.py generate [COMMON_NAME] [--type TYPE] [--days DAYS] [--email EMAIL] [--validation-method METHOD] [--staging|--production] ``` - The `--days` option overrides the `default_validity_days` from the config file. + The `COMMON_NAME` parameter is optional. If not provided, the UniFi host from the config file will be used. This ensures that the certificate is valid for the UniFi device. + + Options: + - `--type`: Type of certificate to generate (self-signed or letsencrypt, default: letsencrypt) + - `--days`: Days valid (overrides config, only for self-signed certificates) + - `--email`: Email address for Let's Encrypt registration (overrides config) + - `--validation-method`: Method to use for domain validation (standalone, webroot, dns) + - `--staging`: Use Let's Encrypt's staging environment (for testing) + - `--production`: Use Let's Encrypt's production environment 3. **Validate Certificate Chain**: ```bash python src/ssl_manager.py validate path/to/certificate.crt [--ca-path path/to/ca.crt] ``` +### Let's Encrypt Validation Methods + +When generating certificates with Let's Encrypt, you need to prove that you control the domain. The SSL Manager supports three validation methods: + +1. **Standalone** (`--validation-method standalone`): + - Starts a temporary web server on port 80 to respond to Let's Encrypt's validation requests + - Requires port 80 to be available and accessible from the internet + - Best for servers where you don't have a web server running + - **Requires the hostname to be in public DNS** with an A/AAAA record pointing to your server + +2. **Webroot** (`--validation-method webroot`): + - Uses an existing web server to serve validation files + - Requires write access to the web server's document root (default: /var/www/html) + - Best for servers with an existing web server + - **Requires the hostname to be in public DNS** with an A/AAAA record pointing to your server + +3. **DNS** (`--validation-method dns`): + - Uses DNS TXT records for validation + - Requires manual intervention to add DNS records + - Best for validating wildcard certificates or when port 80 is not accessible + - **Requires the hostname to be in public DNS** where you can add TXT records + +By default, the SSL Manager uses Let's Encrypt's production environment, which issues trusted certificates. For testing purposes, use the `--staging` flag to use Let's Encrypt's staging environment, which has higher rate limits but issues untrusted certificates. Once you've confirmed everything works with the staging environment, you can remove the `--staging` flag to use the production environment. + +### Public DNS Requirements + +**Yes, the hostname must be in a public DNS for Let's Encrypt certificates.** Let's Encrypt needs to verify that you control the domain before issuing a certificate. The SSL Manager automatically checks if the hostname is in public DNS before attempting to generate a Let's Encrypt certificate and stops with an error if it's not. + +This verification process requires: + +1. For **standalone** and **webroot** validation methods: + - The domain must have a public DNS record (A or AAAA) pointing to your server + - Your server must be publicly accessible on port 80 + - Let's Encrypt servers must be able to reach your server over the internet + +2. For **DNS** validation method: + - The domain must have public DNS records where you can add TXT records + - You don't need a publicly accessible server, but you need control over the domain's DNS records + +#### Alternative Approaches for Private Networks + +If you're using the SSL Manager in a private network where the hostname isn't in public DNS, consider these alternatives: + +1. **Self-signed certificates**: Use `--type self-signed` for internal use only (browsers will show warnings) +2. **Private CA**: Set up your own Certificate Authority for your internal network +3. **Split DNS**: Configure your DNS to resolve the domain internally while also having it in public DNS +4. **Domain with DNS API**: Use a domain you control with DNS API support for automated DNS validation + +### Certificate Verification + +The SSL Manager automatically verifies the current certificate for the UniFi host after initialization. When you run any command, the SSL Manager will: + +1. Check if a certificate file exists for the UniFi host in the certificate directory +2. If it exists, validate it using OpenSSL +3. Display the verification status + +Example output: +``` +Certificate for unifi.example.com: + Status: Valid + Path: /home/user/.ssl-certs/unifi.example.com.crt + Message: Certificate for unifi.example.com is valid +``` + +Possible status values: +- **Valid**: The certificate exists and is valid +- **Invalid**: The certificate exists but is invalid (e.g., expired, self-signed, or not trusted) +- **Missing**: No certificate file was found for the UniFi host +- **Not configured**: No UniFi host is configured in the config file + +### Host and Certificate Validity + +For a certificate to be valid for a UniFi device, the Common Name (CN) in the certificate must match the hostname of the device. This is why the SSL Manager uses the UniFi host from the config file as the default common_name when generating certificates. + +When you access your UniFi device through a web browser, the browser checks that the hostname in the URL matches the Common Name in the certificate. If they don't match, the browser will display a security warning. + +For example: +- If your UniFi device is accessed at `https://udm-se.example.com` +- The Common Name in the certificate should be `udm-se.example.com` + +By configuring the `host` field in the config file and using it as the default common_name, the SSL Manager ensures that the generated certificate will be valid for your UniFi device. + ## Testing Information ### Running Tests @@ -218,8 +322,8 @@ To automate certificate updates using cron: # Activate the virtual environment source .venv/bin/activate - # Run the SSL Manager to update certificates - python src/ssl_manager.py generate your-unifi-device.example.com --days 90 + # Run the SSL Manager to update certificates with Let's Encrypt + python src/ssl_manager.py generate your-unifi-device.example.com --type letsencrypt --email admin@example.com --validation-method standalone --production # Additional commands to deploy the certificate to the UniFi device can be added here ``` @@ -259,8 +363,41 @@ To automate certificate updates using cron: Make sure you're running tests from the project root directory. +4. **Hostname Not in Public DNS Error** + + When generating a Let's Encrypt certificate, you may see an error like: + ``` + Error generating Let's Encrypt certificate: Hostname example.com is not in public DNS. Let's Encrypt requires the hostname to be in public DNS. + ``` + + This means the hostname you're trying to use doesn't resolve to a public IP address. To fix this: + - Verify that the hostname has a public DNS record (A or AAAA) pointing to your server + - Check that the DNS record has propagated (this can take up to 48 hours) + - If you're using a private hostname, consider using a self-signed certificate instead with `--type self-signed` + - For testing purposes, you can use a hostname that is already in public DNS + ### Debugging -- Set the `SSL_DEBUG=1` environment variable for verbose output +- Use the `--debug` flag to enable detailed logging with line numbers and file names: + ```bash + python src/ssl_manager.py --debug check example.com + ``` + +- Set the `debug` option to `true` in the config.json file to always enable debug logging: + ```json + { + "cert_dir": "~/.ssl-certs", + "default_port": 443, + "debug": true, + "connection_timeout": 3.0 + } + ``` + +- Debug logs include: + - Line numbers and file names for each log message + - Detailed information about each operation + - Command execution details + - Error messages with stack traces + - Check the OpenSSL version with `openssl version` - Verify certificate paths are correct and accessible \ No newline at end of file diff --git a/config.json b/config.json index d37324f..22ac87f 100644 --- a/config.json +++ b/config.json @@ -4,6 +4,7 @@ "connection_timeout": 3.0, "default_validity_days": 365, "key_size": 2048, + "debug": false, "unifi": { "host": "udm-se.mgeppert.com", "username": "SSLCertificate", @@ -14,12 +15,19 @@ "ssh_password": "RH6X64FAAiE7CrcV84lQ", "ssh_key_path": "~/.ssh/id_rsa" }, + "letsencrypt": { + "email": "mgeppert1@gmail.com", + "validation_method": "standalone", + "use_staging": false, + "agree_tos": true + }, "comments": { "cert_dir": "Directory where certificates and keys will be stored", "default_port": "Default port to use when checking certificate expiration", "connection_timeout": "Timeout in seconds for SSL connections", "default_validity_days": "Default validity period in days for generated certificates", "key_size": "Key size in bits for generated certificates", + "debug": "Enable debug logging with line numbers and file names (default: false)", "unifi": "UniFi device connection parameters", "unifi.host": "Hostname or IP address of the UniFi device", "unifi.username": "Username for authenticating with the UniFi device", @@ -28,6 +36,11 @@ "unifi.ssh_port": "SSH port for the UniFi device (default: 22)", "unifi.ssh_username": "Username for SSH authentication with the UniFi device", "unifi.ssh_password": "Password for SSH authentication (leave empty to use SSH key)", - "unifi.ssh_key_path": "Path to the SSH private key file for authentication" + "unifi.ssh_key_path": "Path to the SSH private key file for authentication", + "letsencrypt": "Let's Encrypt certificate settings", + "letsencrypt.email": "Email address for Let's Encrypt registration and important notifications", + "letsencrypt.validation_method": "Method to use for domain validation (standalone, webroot, dns)", + "letsencrypt.use_staging": "Whether to use Let's Encrypt's staging environment for testing (true/false)", + "letsencrypt.agree_tos": "Whether to automatically agree to the Terms of Service (true/false)" } } \ No newline at end of file diff --git a/src/ssl_manager.py b/src/ssl_manager.py index 70da71b..078f015 100644 --- a/src/ssl_manager.py +++ b/src/ssl_manager.py @@ -15,7 +15,100 @@ import datetime import argparse import subprocess import json -from typing import Dict, Tuple, Optional, List, Any +import logging +import inspect +import sys +import ipaddress +from typing import Dict, Tuple, Optional, List, Any, Union + + +def setup_logging(debug: bool = False) -> None: + """ + Set up logging configuration. + + Args: + debug: Whether to enable debug logging (default: False) + """ + # Reset root logger + root = logging.getLogger() + if root.handlers: + for handler in root.handlers: + root.removeHandler(handler) + + # Create console handler and set level + console = logging.StreamHandler(sys.stdout) + + if debug: + # Configure logging with line numbers and file names + root.setLevel(logging.DEBUG) + formatter = logging.Formatter( + '%(asctime)s - %(levelname)s - %(filename)s:%(lineno)d - %(message)s', + datefmt='%Y-%m-%d %H:%M:%S' + ) + console.setLevel(logging.DEBUG) + else: + # Configure basic logging for warnings and errors + root.setLevel(logging.INFO) + formatter = logging.Formatter( + '%(asctime)s - %(levelname)s - %(message)s', + datefmt='%Y-%m-%d %H:%M:%S' + ) + console.setLevel(logging.INFO) + + # Set formatter and add handler to root logger + console.setFormatter(formatter) + root.addHandler(console) + + if debug: + logging.debug("Debug logging enabled") + + +def is_hostname_in_public_dns(hostname: str, timeout: float = 3.0) -> bool: + """ + Check if a hostname is in public DNS by resolving it to an IP address + and checking if the IP is public (not in private IP ranges). + + Args: + hostname: The hostname to check + timeout: Timeout in seconds for DNS resolution (default: 3.0) + + Returns: + True if the hostname resolves to a public IP address, False otherwise + """ + logging.debug(f"Checking if hostname {hostname} is in public DNS") + + # Set socket timeout + original_timeout = socket.getdefaulttimeout() + socket.setdefaulttimeout(timeout) + + try: + # Resolve hostname to IP address + ip_str = socket.gethostbyname(hostname) + logging.debug(f"Hostname {hostname} resolved to IP address {ip_str}") + + # Check if IP is private + ip = ipaddress.ip_address(ip_str) + is_public = not (ip.is_private or ip.is_loopback or ip.is_link_local) + + if is_public: + logging.debug(f"IP address {ip_str} is public") + return True + else: + logging.debug(f"IP address {ip_str} is private (not in public DNS)") + return False + + except socket.gaierror as e: + logging.error(f"Failed to resolve hostname {hostname}: {str(e)}") + return False + except socket.timeout: + logging.error(f"Timeout while resolving hostname {hostname}") + return False + except Exception as e: + logging.error(f"Error checking if hostname {hostname} is in public DNS: {str(e)}") + return False + finally: + # Restore original timeout + socket.setdefaulttimeout(original_timeout) def load_config(config_path: str = "config.json") -> Dict[str, Any]: @@ -28,12 +121,15 @@ def load_config(config_path: str = "config.json") -> Dict[str, Any]: Returns: Dictionary containing configuration values """ + logging.debug(f"Loading configuration from {config_path}") + default_config = { "cert_dir": "~/.ssl-certs", "default_port": 443, "connection_timeout": 3.0, "default_validity_days": 365, "key_size": 2048, + "debug": False, "unifi": { "host": "", "username": "", @@ -43,28 +139,50 @@ def load_config(config_path: str = "config.json") -> Dict[str, Any]: "ssh_username": "", "ssh_password": "", "ssh_key_path": "~/.ssh/id_rsa" + }, + "letsencrypt": { + "email": "", + "validation_method": "standalone", + "use_staging": True, + "agree_tos": True } } + logging.debug("Default configuration initialized") try: + logging.debug(f"Attempting to open config file: {config_path}") with open(config_path, 'r') as f: config = json.load(f) + logging.debug(f"Config file loaded successfully") + # Remove comments section if present if "comments" in config: + logging.debug("Removing comments section from config") del config["comments"] + # Update default config with values from file + logging.debug("Updating default config with values from file") for key in default_config: if key in config: if isinstance(default_config[key], dict) and isinstance(config[key], dict): # Handle nested dictionaries (like unifi) + logging.debug(f"Processing nested dictionary: {key}") for nested_key in default_config[key]: if nested_key in config[key]: + logging.debug(f" Setting {key}.{nested_key}") default_config[key][nested_key] = config[key][nested_key] else: + logging.debug(f"Setting {key}") default_config[key] = config[key] - except (FileNotFoundError, json.JSONDecodeError): - # Use default config if file not found or invalid - pass + except FileNotFoundError: + logging.warning(f"Config file not found: {config_path}, using default configuration") + except json.JSONDecodeError as e: + logging.error(f"Error parsing config file: {e}, using default configuration") + except Exception as e: + logging.error(f"Unexpected error loading config: {e}, using default configuration") + + # Log the final configuration + logging.debug(f"Final configuration: {default_config}") return default_config @@ -80,30 +198,63 @@ class SSLManager: cert_dir: Directory to store certificates (default: None) config_path: Path to the config file (default: config.json) """ + logging.debug(f"Initializing SSLManager with config_path={config_path}, cert_dir={cert_dir}") + # Load configuration + logging.debug(f"Loading configuration from {config_path}") self.config = load_config(config_path) # Use cert_dir from parameters if provided, otherwise from config self.cert_dir = cert_dir or os.path.expanduser(self.config["cert_dir"]) - os.makedirs(self.cert_dir, exist_ok=True) + logging.debug(f"Using certificate directory: {self.cert_dir}") + + # Create certificate directory if it doesn't exist + if not os.path.exists(self.cert_dir): + logging.debug(f"Creating certificate directory: {self.cert_dir}") + os.makedirs(self.cert_dir, exist_ok=True) # Store other config values self.default_port = self.config["default_port"] self.connection_timeout = self.config["connection_timeout"] self.default_validity_days = self.config["default_validity_days"] self.key_size = self.config["key_size"] + logging.debug(f"Loaded config values: default_port={self.default_port}, " + f"connection_timeout={self.connection_timeout}, " + f"default_validity_days={self.default_validity_days}, " + f"key_size={self.key_size}") # Store UniFi device connection parameters self.unifi_host = self.config["unifi"]["host"] self.unifi_username = self.config["unifi"]["username"] self.unifi_password = self.config["unifi"]["password"] self.unifi_site = self.config["unifi"]["site"] + logging.debug(f"Loaded UniFi connection parameters: host={self.unifi_host}, " + f"username={self.unifi_username}, " + f"site={self.unifi_site}") # Store UniFi device SSH parameters self.unifi_ssh_port = self.config["unifi"]["ssh_port"] self.unifi_ssh_username = self.config["unifi"]["ssh_username"] self.unifi_ssh_password = self.config["unifi"]["ssh_password"] self.unifi_ssh_key_path = self.config["unifi"]["ssh_key_path"] + logging.debug(f"Loaded UniFi SSH parameters: port={self.unifi_ssh_port}, " + f"username={self.unifi_ssh_username}, " + f"key_path={self.unifi_ssh_key_path}") + + # Store Let's Encrypt settings + self.letsencrypt_email = self.config["letsencrypt"]["email"] + self.letsencrypt_validation_method = self.config["letsencrypt"]["validation_method"] + self.letsencrypt_use_staging = self.config["letsencrypt"]["use_staging"] + self.letsencrypt_agree_tos = self.config["letsencrypt"]["agree_tos"] + logging.debug(f"Loaded Let's Encrypt settings: email={self.letsencrypt_email}, " + f"validation_method={self.letsencrypt_validation_method}, " + f"use_staging={self.letsencrypt_use_staging}, " + f"agree_tos={self.letsencrypt_agree_tos}") + + # Verify current certificate after initialization + logging.debug("Verifying current certificate after initialization") + self.cert_verification = self.verify_current_certificate() + logging.info(f"Certificate verification status: {self.cert_verification['status']}") def check_cert_expiration(self, hostname: str, port: int = None) -> Dict: """ @@ -118,6 +269,7 @@ class SSLManager: """ # Use provided port or default from config port = port or self.default_port + logging.debug(f"Checking certificate for {hostname}:{port}") context = ssl.create_default_context() conn = context.wrap_socket( @@ -127,20 +279,25 @@ class SSLManager: # Use timeout from config conn.settimeout(self.connection_timeout) + logging.debug(f"Connection timeout set to {self.connection_timeout} seconds") try: + logging.debug(f"Connecting to {hostname}:{port}") conn.connect((hostname, port)) cert = conn.getpeercert() + logging.debug("Connection established, certificate retrieved") # Parse expiration date expiration_date = datetime.datetime.strptime( cert['notAfter'], '%b %d %H:%M:%S %Y %Z' ) + logging.debug(f"Certificate expiration date: {expiration_date}") # Calculate days until expiration days_left = (expiration_date - datetime.datetime.now()).days + logging.debug(f"Days until expiration: {days_left}") - return { + result = { 'hostname': hostname, 'port': port, 'issuer': dict(x[0] for x in cert['issuer']), @@ -149,7 +306,10 @@ class SSLManager: 'days_left': days_left, 'status': 'Valid' if days_left > 0 else 'Expired' } + logging.debug(f"Certificate status: {result['status']}") + return result except Exception as e: + logging.error(f"Error checking certificate for {hostname}:{port}: {str(e)}") return { 'hostname': hostname, 'port': port, @@ -158,6 +318,7 @@ class SSLManager: } finally: conn.close() + logging.debug("Connection closed") def generate_self_signed_cert( self, @@ -176,27 +337,49 @@ class SSLManager: """ # Use provided days_valid or default from config days_valid = days_valid or self.default_validity_days + logging.debug(f"Generating self-signed certificate for {common_name} valid for {days_valid} days") cert_path = os.path.join(self.cert_dir, f"{common_name}.crt") key_path = os.path.join(self.cert_dir, f"{common_name}.key") + logging.debug(f"Certificate path: {cert_path}") + logging.debug(f"Private key path: {key_path}") # Generate private key using key size from config - subprocess.run([ + logging.debug(f"Generating private key with size {self.key_size} bits") + key_cmd = [ 'openssl', 'genrsa', '-out', key_path, str(self.key_size) - ], check=True) + ] + logging.debug(f"Running command: {' '.join(key_cmd)}") + try: + subprocess.run(key_cmd, check=True, capture_output=True, text=True) + logging.debug("Private key generated successfully") + except subprocess.CalledProcessError as e: + logging.error(f"Error generating private key: {e}") + logging.debug(f"Command output: {e.stdout}\n{e.stderr}") + raise # Generate certificate - subprocess.run([ + logging.debug(f"Generating certificate with subject /CN={common_name}") + cert_cmd = [ 'openssl', 'req', '-new', '-x509', '-key', key_path, '-out', cert_path, '-days', str(days_valid), '-subj', f"/CN={common_name}" - ], check=True) + ] + logging.debug(f"Running command: {' '.join(cert_cmd)}") + try: + subprocess.run(cert_cmd, check=True, capture_output=True, text=True) + logging.debug("Certificate generated successfully") + except subprocess.CalledProcessError as e: + logging.error(f"Error generating certificate: {e}") + logging.debug(f"Command output: {e.stdout}\n{e.stderr}") + raise + logging.info(f"Generated self-signed certificate for {common_name}") return cert_path, key_path def validate_cert_chain(self, cert_path: str, ca_path: Optional[str] = None) -> bool: @@ -210,15 +393,38 @@ class SSLManager: Returns: True if valid, False otherwise """ + logging.debug(f"Validating certificate chain for {cert_path}") + + # Check if certificate file exists + if not os.path.isfile(cert_path): + logging.error(f"Certificate file not found: {cert_path}") + return False + + # Check if CA file exists if provided + if ca_path and not os.path.isfile(ca_path): + logging.error(f"CA certificate file not found: {ca_path}") + return False + cmd = ['openssl', 'verify'] if ca_path: + logging.debug(f"Using CA certificate: {ca_path}") cmd.extend(['-CAfile', ca_path]) + else: + logging.debug("Using system CA certificates") cmd.append(cert_path) + logging.debug(f"Running command: {' '.join(cmd)}") result = subprocess.run(cmd, capture_output=True, text=True) - return result.returncode == 0 and 'OK' in result.stdout + + if result.returncode == 0 and 'OK' in result.stdout: + logging.debug(f"Certificate validation successful: {result.stdout.strip()}") + return True + else: + logging.error(f"Certificate validation failed: {result.stderr.strip()}") + logging.debug(f"Command output: {result.stdout}\n{result.stderr}") + return False def get_unifi_connection_params(self) -> Dict[str, str]: """ @@ -227,12 +433,15 @@ class SSLManager: Returns: Dictionary containing the UniFi device connection parameters """ - return { + logging.debug("Getting UniFi device connection parameters") + params = { "host": self.unifi_host, "username": self.unifi_username, "password": self.unifi_password, "site": self.unifi_site } + logging.debug(f"UniFi connection parameters: {params}") + return params def get_unifi_ssh_params(self) -> Dict[str, str]: """ @@ -241,13 +450,221 @@ class SSLManager: Returns: Dictionary containing the UniFi device SSH parameters """ - return { + logging.debug("Getting UniFi device SSH parameters") + params = { "host": self.unifi_host, "port": self.unifi_ssh_port, "username": self.unifi_ssh_username, "password": self.unifi_ssh_password, "key_path": self.unifi_ssh_key_path } + logging.debug(f"UniFi SSH parameters: {params}") + return params + + def generate_letsencrypt_cert( + self, + common_name: str, + email: str = None, + validation_method: str = None, + use_staging: bool = None, + agree_tos: bool = None + ) -> Tuple[str, str]: + """ + Generate a certificate using Let's Encrypt. + + Args: + common_name: Common Name (CN) for the certificate (domain name) + email: Email address for Let's Encrypt registration (default: from config) + validation_method: Method to use for domain validation (default: from config) + use_staging: Whether to use Let's Encrypt's staging environment (default: from config) + agree_tos: Whether to automatically agree to the Terms of Service (default: from config) + + Returns: + Tuple of (cert_path, key_path) + """ + # Use provided values or defaults from config + email = email or self.letsencrypt_email + validation_method = validation_method or self.letsencrypt_validation_method + use_staging = use_staging if use_staging is not None else self.letsencrypt_use_staging + agree_tos = agree_tos if agree_tos is not None else self.letsencrypt_agree_tos + + logging.debug(f"Generating Let's Encrypt certificate for {common_name}") + logging.debug(f"Using email: {email}") + logging.debug(f"Using validation method: {validation_method}") + logging.debug(f"Using staging environment: {use_staging}") + logging.debug(f"Automatically agree to ToS: {agree_tos}") + + # Check if email is provided (required by Let's Encrypt) + if not email: + error_msg = "Email address is required for Let's Encrypt registration" + logging.error(error_msg) + raise ValueError(error_msg) + + # Check if hostname is in public DNS + logging.debug(f"Checking if hostname {common_name} is in public DNS") + if not is_hostname_in_public_dns(common_name, timeout=self.connection_timeout): + error_msg = f"Hostname {common_name} is not in public DNS. Let's Encrypt requires the hostname to be in public DNS." + logging.error(error_msg) + raise ValueError(error_msg) + logging.debug(f"Hostname {common_name} is in public DNS") + + # Define paths for certificate and key + cert_path = os.path.join(self.cert_dir, f"{common_name}.crt") + key_path = os.path.join(self.cert_dir, f"{common_name}.key") + logging.debug(f"Certificate path: {cert_path}") + logging.debug(f"Private key path: {key_path}") + + # Build certbot command + cmd = ['certbot', 'certonly'] + + # Add validation method + if validation_method == 'standalone': + cmd.append('--standalone') + elif validation_method == 'webroot': + cmd.extend(['--webroot', '--webroot-path', '/var/www/html']) + elif validation_method == 'dns': + cmd.append('--manual') + cmd.append('--preferred-challenges=dns') + else: + error_msg = f"Unsupported validation method: {validation_method}" + logging.error(error_msg) + raise ValueError(error_msg) + + # Add domain + cmd.extend(['-d', common_name]) + + # Add email + cmd.extend(['-m', email]) + + # Add staging flag if needed + if use_staging: + cmd.append('--test-cert') + + # Add agree to ToS flag if needed + if agree_tos: + cmd.append('--agree-tos') + + # Add non-interactive flag + cmd.append('-n') + + # Specify config directory in our cert_dir to avoid requiring sudo + config_dir = os.path.join(self.cert_dir, '.config') + work_dir = os.path.join(self.cert_dir, '.work') + logs_dir = os.path.join(self.cert_dir, '.logs') + + # Create directories if they don't exist + for directory in [config_dir, work_dir, logs_dir]: + if not os.path.exists(directory): + os.makedirs(directory, exist_ok=True) + + # Add config directory options + cmd.extend(['--config-dir', config_dir]) + cmd.extend(['--work-dir', work_dir]) + cmd.extend(['--logs-dir', logs_dir]) + + # Log the command + logging.debug(f"Running command: {' '.join(cmd)}") + + try: + # Run certbot + result = subprocess.run(cmd, check=True, capture_output=True, text=True) + logging.debug(f"Command output: {result.stdout}") + logging.info(f"Generated Let's Encrypt certificate for {common_name}") + + # Find and copy the certificate and key files from certbot's directory structure + certbot_live_dir = os.path.join(config_dir, 'live', common_name) + if not os.path.exists(certbot_live_dir): + error_msg = f"Certificate directory not found: {certbot_live_dir}" + logging.error(error_msg) + raise FileNotFoundError(error_msg) + + # Copy the certificate (fullchain.pem) to our expected location + fullchain_path = os.path.join(certbot_live_dir, 'fullchain.pem') + if not os.path.isfile(fullchain_path): + error_msg = f"Certificate file not found: {fullchain_path}" + logging.error(error_msg) + raise FileNotFoundError(error_msg) + + logging.debug(f"Copying certificate from {fullchain_path} to {cert_path}") + with open(fullchain_path, 'rb') as src, open(cert_path, 'wb') as dst: + dst.write(src.read()) + + # Copy the private key (privkey.pem) to our expected location + privkey_path = os.path.join(certbot_live_dir, 'privkey.pem') + if not os.path.isfile(privkey_path): + error_msg = f"Private key file not found: {privkey_path}" + logging.error(error_msg) + raise FileNotFoundError(error_msg) + + logging.debug(f"Copying private key from {privkey_path} to {key_path}") + with open(privkey_path, 'rb') as src, open(key_path, 'wb') as dst: + dst.write(src.read()) + + return cert_path, key_path + + except subprocess.CalledProcessError as e: + logging.error(f"Error generating Let's Encrypt certificate: {e}") + logging.debug(f"Command output: {e.stdout}\n{e.stderr}") + raise + + def verify_current_certificate(self) -> Dict[str, Any]: + """ + Verify if the current certificate for the UniFi host exists and is valid. + + This method checks if a certificate file exists for the UniFi host in the + certificate directory and validates it if it exists. + + Returns: + Dictionary with certificate verification status and details + """ + logging.debug(f"Verifying current certificate for UniFi host: {self.unifi_host}") + + # Skip verification if no UniFi host is configured + if not self.unifi_host: + logging.warning("No UniFi host configured, skipping certificate verification") + return { + 'exists': False, + 'valid': False, + 'status': 'Not configured', + 'message': 'No UniFi host configured' + } + + # Construct the expected certificate path + cert_path = os.path.join(self.cert_dir, f"{self.unifi_host}.crt") + logging.debug(f"Expected certificate path: {cert_path}") + + # Check if certificate file exists + if not os.path.isfile(cert_path): + logging.warning(f"Certificate file not found: {cert_path}") + return { + 'exists': False, + 'valid': False, + 'status': 'Missing', + 'message': f"Certificate file not found: {cert_path}" + } + + # Validate the certificate + logging.debug(f"Certificate file found, validating: {cert_path}") + is_valid = self.validate_cert_chain(cert_path) + + if is_valid: + logging.info(f"Certificate for {self.unifi_host} is valid") + return { + 'exists': True, + 'valid': True, + 'status': 'Valid', + 'message': f"Certificate for {self.unifi_host} is valid", + 'cert_path': cert_path + } + else: + logging.warning(f"Certificate for {self.unifi_host} is invalid") + return { + 'exists': True, + 'valid': False, + 'status': 'Invalid', + 'message': f"Certificate for {self.unifi_host} is invalid", + 'cert_path': cert_path + } def main(): @@ -257,6 +674,7 @@ def main(): # Global arguments parser.add_argument('--config', help='Path to config file (default: config.json)', default='config.json') parser.add_argument('--cert-dir', help='Directory to store certificates (overrides config)') + parser.add_argument('--debug', action='store_true', help='Enable debug logging with line numbers and file names') subparsers = parser.add_subparsers(dest='command', help='Command to run') @@ -265,10 +683,19 @@ def main(): check_parser.add_argument('hostname', help='Hostname to check') check_parser.add_argument('--port', type=int, help='Port (overrides config)') - # Generate self-signed certificate command - gen_parser = subparsers.add_parser('generate', help='Generate self-signed certificate') - gen_parser.add_argument('common_name', help='Common Name (CN) for the certificate') + # Generate certificate command + gen_parser = subparsers.add_parser('generate', help='Generate certificate') + gen_parser.add_argument('common_name', nargs='?', help='Common Name (CN) for the certificate (defaults to UniFi host from config)') gen_parser.add_argument('--days', type=int, help='Days valid (overrides config)') + gen_parser.add_argument('--type', choices=['self-signed', 'letsencrypt'], default='letsencrypt', + help='Type of certificate to generate (default: letsencrypt)') + + # Let's Encrypt specific options + gen_parser.add_argument('--email', help='Email address for Let\'s Encrypt registration') + gen_parser.add_argument('--validation-method', choices=['standalone', 'webroot', 'dns'], + help='Method to use for domain validation') + gen_parser.add_argument('--staging', action='store_true', help='Use Let\'s Encrypt staging environment') + gen_parser.add_argument('--production', action='store_true', help='Use Let\'s Encrypt production environment') # Validate certificate command validate_parser = subparsers.add_parser('validate', help='Validate certificate chain') @@ -276,31 +703,117 @@ def main(): validate_parser.add_argument('--ca-path', help='Path to CA certificate') args = parser.parse_args() + + # Load configuration to get debug setting + config = load_config(args.config) + + # Use debug flag from command line or from config file + debug_enabled = args.debug or config.get("debug", False) + + # Set up logging based on debug flag + setup_logging(debug_enabled) + + # Log debug information about arguments and config + if debug_enabled: + logging.debug(f"Command-line arguments: {args}") + logging.debug(f"Using config file: {args.config}") + logging.debug(f"Debug enabled via {'command line' if args.debug else 'config file'}") + ssl_manager = SSLManager(cert_dir=args.cert_dir, config_path=args.config) + # Display certificate verification status + if ssl_manager.unifi_host: + print(f"Certificate for {ssl_manager.unifi_host}:") + print(f" Status: {ssl_manager.cert_verification['status']}") + if ssl_manager.cert_verification['exists']: + print(f" Path: {ssl_manager.cert_verification.get('cert_path', 'N/A')}") + print(f" Message: {ssl_manager.cert_verification['message']}") + if args.command == 'check': + logging.debug(f"Checking certificate expiration for {args.hostname}:{args.port or ssl_manager.default_port}") result = ssl_manager.check_cert_expiration(args.hostname, args.port) if result['status'] == 'Error': - print(f"Error checking {args.hostname}:{args.port}: {result['error']}") + logging.error(f"Error checking {args.hostname}:{args.port or ssl_manager.default_port}: {result['error']}") + print(f"Error checking {args.hostname}:{args.port or ssl_manager.default_port}: {result['error']}") else: + logging.info(f"Certificate for {result['hostname']} expires in {result['days_left']} days") + logging.debug(f"Certificate details: {result}") print(f"Certificate for {result['hostname']}:") print(f" Status: {result['status']}") print(f" Expires: {result['expiration_date']} ({result['days_left']} days left)") print(f" Issuer: {result['issuer'].get('organizationName', 'N/A')}") elif args.command == 'generate': - cert_path, key_path = ssl_manager.generate_self_signed_cert( - args.common_name, args.days - ) - print(f"Generated self-signed certificate:") - print(f" Certificate: {cert_path}") - print(f" Private Key: {key_path}") + # Determine certificate type + cert_type = args.type + + # Handle conflicting options + if args.staging and args.production: + logging.error("Cannot use both --staging and --production options") + print("Error: Cannot use both --staging and --production options") + sys.exit(1) + + # Determine staging setting + use_staging = None + if args.staging: + use_staging = True + elif args.production: + use_staging = False + + # Use UniFi host as common_name if not provided + common_name = args.common_name + if common_name is None: + common_name = ssl_manager.unifi_host + if not common_name: + logging.error("No common_name provided and no UniFi host configured in config file") + print("Error: No common_name provided and no UniFi host configured in config file") + print("Please provide a common_name or configure a UniFi host in the config file") + sys.exit(1) + logging.info(f"Using UniFi host '{common_name}' as common_name") + + if cert_type == 'self-signed': + logging.debug(f"Generating self-signed certificate for {common_name} valid for {args.days or ssl_manager.default_validity_days} days") + cert_path, key_path = ssl_manager.generate_self_signed_cert( + common_name, args.days + ) + logging.info(f"Generated self-signed certificate for {common_name}") + logging.debug(f"Certificate path: {cert_path}") + logging.debug(f"Private key path: {key_path}") + print(f"Generated self-signed certificate:") + print(f" Certificate: {cert_path}") + print(f" Private Key: {key_path}") + + elif cert_type == 'letsencrypt': + logging.debug(f"Generating Let's Encrypt certificate for {common_name}") + try: + cert_path, key_path = ssl_manager.generate_letsencrypt_cert( + common_name=common_name, + email=args.email, + validation_method=args.validation_method, + use_staging=use_staging, + agree_tos=True # Always agree to ToS from command line + ) + logging.info(f"Generated Let's Encrypt certificate for {common_name}") + logging.debug(f"Certificate path: {cert_path}") + logging.debug(f"Private key path: {key_path}") + print(f"Generated Let's Encrypt certificate:") + print(f" Certificate: {cert_path}") + print(f" Private Key: {key_path}") + except Exception as e: + logging.error(f"Error generating Let's Encrypt certificate: {str(e)}") + print(f"Error generating Let's Encrypt certificate: {str(e)}") + sys.exit(1) elif args.command == 'validate': + logging.debug(f"Validating certificate chain for {args.cert_path}") + if args.ca_path: + logging.debug(f"Using CA certificate: {args.ca_path}") is_valid = ssl_manager.validate_cert_chain(args.cert_path, args.ca_path) + logging.info(f"Certificate validation {'successful' if is_valid else 'failed'}") print(f"Certificate validation {'successful' if is_valid else 'failed'}") else: + logging.debug("No command specified, showing help") parser.print_help() diff --git a/test_no_host.json b/test_no_host.json new file mode 100644 index 0000000..eb6b218 --- /dev/null +++ b/test_no_host.json @@ -0,0 +1,24 @@ +{ + "cert_dir": "~/.ssl-certs", + "default_port": 443, + "connection_timeout": 3.0, + "default_validity_days": 365, + "key_size": 2048, + "debug": false, + "unifi": { + "host": "", + "username": "SSLCertificate", + "password": "cYu2E1OWt0XseVf9j5ML", + "site": "default", + "ssh_port": 22, + "ssh_username": "root", + "ssh_password": "RH6X64FAAiE7CrcV84lQ", + "ssh_key_path": "~/.ssh/id_rsa" + }, + "letsencrypt": { + "email": "mgeppert1@gmail.com", + "validation_method": "standalone", + "use_staging": true, + "agree_tos": true + } +} \ No newline at end of file diff --git a/tests/test_cert_verification.py b/tests/test_cert_verification.py new file mode 100644 index 0000000..ef3da09 --- /dev/null +++ b/tests/test_cert_verification.py @@ -0,0 +1,139 @@ +#!/usr/bin/env python3 +""" +Tests for the certificate verification functionality of the SSL Manager. + +This module contains tests for verifying certificates after initialization. +""" + +import os +import sys +import json +import tempfile +import unittest +from unittest.mock import patch, MagicMock + +# Add the src directory to the Python path +sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '../src'))) + +from ssl_manager import SSLManager + + +class TestCertVerification(unittest.TestCase): + """Test cases for certificate verification functionality.""" + + def setUp(self): + """Set up test fixtures.""" + # Create a temporary directory for test files + self.temp_dir = tempfile.TemporaryDirectory() + + # Sample config for testing + self.test_config = { + "cert_dir": self.temp_dir.name, + "default_port": 8443, + "connection_timeout": 5.0, + "default_validity_days": 730, + "key_size": 4096, + "unifi": { + "host": "test.unifi.local", + "username": "testuser", + "password": "testpass", + "site": "testsite", + "ssh_port": 2222, + "ssh_username": "sshuser", + "ssh_password": "sshpass", + "ssh_key_path": "~/test-ssh-key" + } + } + + # Create a temporary config file + self.config_path = os.path.join(self.temp_dir.name, "test_config.json") + with open(self.config_path, 'w') as f: + json.dump(self.test_config, f) + + def tearDown(self): + """Tear down test fixtures.""" + # Clean up the temporary directory + self.temp_dir.cleanup() + + def test_verify_current_certificate_missing(self): + """Test verification when certificate is missing.""" + # Create an SSLManager with the test config + ssl_manager = SSLManager(config_path=self.config_path) + + # Verify that cert_verification is set and indicates missing certificate + self.assertIsNotNone(ssl_manager.cert_verification) + self.assertEqual(ssl_manager.cert_verification['status'], 'Missing') + self.assertFalse(ssl_manager.cert_verification['exists']) + self.assertFalse(ssl_manager.cert_verification['valid']) + + @patch('ssl_manager.SSLManager.validate_cert_chain') + def test_verify_current_certificate_valid(self, mock_validate): + """Test verification when certificate is valid.""" + # Mock the validate_cert_chain method to return True + mock_validate.return_value = True + + # Create a dummy certificate file + cert_path = os.path.join(self.temp_dir.name, "test.unifi.local.crt") + with open(cert_path, 'w') as f: + f.write("-----BEGIN CERTIFICATE-----\nDummy Certificate\n-----END CERTIFICATE-----") + + # Create an SSLManager with the test config + ssl_manager = SSLManager(config_path=self.config_path) + + # Verify that cert_verification is set and indicates valid certificate + self.assertIsNotNone(ssl_manager.cert_verification) + self.assertEqual(ssl_manager.cert_verification['status'], 'Valid') + self.assertTrue(ssl_manager.cert_verification['exists']) + self.assertTrue(ssl_manager.cert_verification['valid']) + self.assertEqual(ssl_manager.cert_verification['cert_path'], cert_path) + + # Verify that validate_cert_chain was called with the correct path + mock_validate.assert_called_once_with(cert_path) + + @patch('ssl_manager.SSLManager.validate_cert_chain') + def test_verify_current_certificate_invalid(self, mock_validate): + """Test verification when certificate is invalid.""" + # Mock the validate_cert_chain method to return False + mock_validate.return_value = False + + # Create a dummy certificate file + cert_path = os.path.join(self.temp_dir.name, "test.unifi.local.crt") + with open(cert_path, 'w') as f: + f.write("-----BEGIN CERTIFICATE-----\nDummy Certificate\n-----END CERTIFICATE-----") + + # Create an SSLManager with the test config + ssl_manager = SSLManager(config_path=self.config_path) + + # Verify that cert_verification is set and indicates invalid certificate + self.assertIsNotNone(ssl_manager.cert_verification) + self.assertEqual(ssl_manager.cert_verification['status'], 'Invalid') + self.assertTrue(ssl_manager.cert_verification['exists']) + self.assertFalse(ssl_manager.cert_verification['valid']) + self.assertEqual(ssl_manager.cert_verification['cert_path'], cert_path) + + # Verify that validate_cert_chain was called with the correct path + mock_validate.assert_called_once_with(cert_path) + + def test_verify_current_certificate_no_host(self): + """Test verification when no UniFi host is configured.""" + # Create a config with no UniFi host + config_no_host = self.test_config.copy() + config_no_host["unifi"]["host"] = "" + + # Create a temporary config file + config_path_no_host = os.path.join(self.temp_dir.name, "test_config_no_host.json") + with open(config_path_no_host, 'w') as f: + json.dump(config_no_host, f) + + # Create an SSLManager with the modified config + ssl_manager = SSLManager(config_path=config_path_no_host) + + # Verify that cert_verification is set and indicates not configured + self.assertIsNotNone(ssl_manager.cert_verification) + self.assertEqual(ssl_manager.cert_verification['status'], 'Not configured') + self.assertFalse(ssl_manager.cert_verification['exists']) + self.assertFalse(ssl_manager.cert_verification['valid']) + + +if __name__ == '__main__': + unittest.main() \ No newline at end of file diff --git a/tests/test_debug.py b/tests/test_debug.py new file mode 100644 index 0000000..86bfaea --- /dev/null +++ b/tests/test_debug.py @@ -0,0 +1,112 @@ +#!/usr/bin/env python3 +""" +Tests for the debug functionality of the SSL Manager. + +This module contains tests for the debug logging functionality. +""" + +import os +import sys +import json +import tempfile +import unittest +import logging +from unittest.mock import patch, MagicMock +from io import StringIO + +# Add the src directory to the Python path +sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) + +from src.ssl_manager import setup_logging, load_config, SSLManager + + +class TestDebugLogging(unittest.TestCase): + """Test cases for debug logging functionality.""" + + def setUp(self): + """Set up test fixtures.""" + # Create a temporary directory for test files + self.temp_dir = tempfile.TemporaryDirectory() + + # Sample config for testing + self.test_config = { + "cert_dir": "~/test-certs", + "default_port": 8443, + "connection_timeout": 5.0, + "default_validity_days": 730, + "key_size": 4096, + "debug": True + } + + # Create a temporary config file + self.config_path = os.path.join(self.temp_dir.name, "test_config.json") + with open(self.config_path, 'w') as f: + json.dump(self.test_config, f) + + def tearDown(self): + """Tear down test fixtures.""" + # Clean up the temporary directory + self.temp_dir.cleanup() + + def test_setup_logging_debug_enabled(self): + """Test that setup_logging configures logging correctly when debug is enabled.""" + # Capture log output + log_capture = StringIO() + handler = logging.StreamHandler(log_capture) + + # Patch the logging.basicConfig to use our handler + with patch('logging.basicConfig') as mock_basic_config: + # Call setup_logging with debug=True + setup_logging(True) + + # Verify that logging was configured with DEBUG level + mock_basic_config.assert_called_once() + args, kwargs = mock_basic_config.call_args + self.assertEqual(kwargs['level'], logging.DEBUG) + + # Verify that the format includes filename and line number + self.assertIn('%(filename)s:%(lineno)d', kwargs['format']) + + def test_setup_logging_debug_disabled(self): + """Test that setup_logging configures logging correctly when debug is disabled.""" + # Patch the logging.basicConfig + with patch('logging.basicConfig') as mock_basic_config: + # Call setup_logging with debug=False + setup_logging(False) + + # Verify that logging was configured with INFO level + mock_basic_config.assert_called_once() + args, kwargs = mock_basic_config.call_args + self.assertEqual(kwargs['level'], logging.INFO) + + # Verify that the format does not include filename and line number + self.assertNotIn('%(filename)s:%(lineno)d', kwargs['format']) + + def test_load_config_with_debug_enabled(self): + """Test loading a configuration file with debug enabled.""" + # Load the config + config = load_config(self.config_path) + + # Verify the debug setting + self.assertTrue(config["debug"]) + + @patch('sys.stdout', new_callable=StringIO) + @patch('sys.stderr', new_callable=StringIO) + def test_debug_output_format(self, mock_stderr, mock_stdout): + """Test that debug output includes line numbers and file names.""" + # Set up logging to capture output + setup_logging(True) + + # Generate a debug log message + logging.debug("Test debug message") + + # Get the captured output + output = mock_stderr.getvalue() + + # Verify that the output includes the filename and line number + self.assertIn('test_debug.py:', output) + self.assertIn(' - Test debug message', output) + + +if __name__ == '__main__': + unittest.main() \ No newline at end of file diff --git a/tests/test_dns_validation.py b/tests/test_dns_validation.py new file mode 100644 index 0000000..bb5def4 --- /dev/null +++ b/tests/test_dns_validation.py @@ -0,0 +1,154 @@ +#!/usr/bin/env python3 +""" +Tests for the DNS validation functionality of the SSL Manager. + +This module contains tests for checking if a hostname is in public DNS. +""" + +import os +import sys +import socket +import unittest +from unittest.mock import patch, MagicMock + +# Add the src directory to the Python path +sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '../src'))) + +# Import from ssl_manager module +from src.ssl_manager import is_hostname_in_public_dns, SSLManager + + +class TestDNSValidation(unittest.TestCase): + """Test cases for DNS validation functionality.""" + + def setUp(self): + """Set up test fixtures.""" + # Create an SSLManager instance for testing + self.ssl_manager = SSLManager() + + @patch('socket.gethostbyname') + def test_hostname_in_public_dns_positive(self, mock_gethostbyname): + """Test that a hostname resolving to a public IP returns True.""" + # Mock socket.gethostbyname to return a public IP address + mock_gethostbyname.return_value = '8.8.8.8' # Google's public DNS server + + # Call the function + result = is_hostname_in_public_dns('example.com') + + # Verify the result + self.assertTrue(result) + mock_gethostbyname.assert_called_once_with('example.com') + + @patch('socket.gethostbyname') + def test_hostname_in_public_dns_negative_private_ip(self, mock_gethostbyname): + """Test that a hostname resolving to a private IP returns False.""" + # Mock socket.gethostbyname to return a private IP address + mock_gethostbyname.return_value = '192.168.1.1' # Private IP address + + # Call the function + result = is_hostname_in_public_dns('internal.local') + + # Verify the result + self.assertFalse(result) + mock_gethostbyname.assert_called_once_with('internal.local') + + @patch('socket.gethostbyname') + def test_hostname_in_public_dns_negative_loopback(self, mock_gethostbyname): + """Test that a hostname resolving to a loopback IP returns False.""" + # Mock socket.gethostbyname to return a loopback IP address + mock_gethostbyname.return_value = '127.0.0.1' # Loopback IP address + + # Call the function + result = is_hostname_in_public_dns('localhost') + + # Verify the result + self.assertFalse(result) + mock_gethostbyname.assert_called_once_with('localhost') + + @patch('socket.gethostbyname') + def test_hostname_in_public_dns_negative_link_local(self, mock_gethostbyname): + """Test that a hostname resolving to a link-local IP returns False.""" + # Mock socket.gethostbyname to return a link-local IP address + mock_gethostbyname.return_value = '169.254.1.1' # Link-local IP address + + # Call the function + result = is_hostname_in_public_dns('link-local.local') + + # Verify the result + self.assertFalse(result) + mock_gethostbyname.assert_called_once_with('link-local.local') + + @patch('socket.gethostbyname') + def test_hostname_in_public_dns_negative_gaierror(self, mock_gethostbyname): + """Test that a hostname that cannot be resolved returns False.""" + # Mock socket.gethostbyname to raise a socket.gaierror + mock_gethostbyname.side_effect = socket.gaierror("Name or service not known") + + # Call the function + result = is_hostname_in_public_dns('nonexistent.example.com') + + # Verify the result + self.assertFalse(result) + mock_gethostbyname.assert_called_once_with('nonexistent.example.com') + + @patch('socket.gethostbyname') + def test_hostname_in_public_dns_negative_timeout(self, mock_gethostbyname): + """Test that a hostname that times out during resolution returns False.""" + # Mock socket.gethostbyname to raise a socket.timeout + mock_gethostbyname.side_effect = socket.timeout("Timed out") + + # Call the function + result = is_hostname_in_public_dns('slow.example.com') + + # Verify the result + self.assertFalse(result) + mock_gethostbyname.assert_called_once_with('slow.example.com') + + @patch('socket.getdefaulttimeout') + @patch('socket.setdefaulttimeout') + @patch('socket.gethostbyname') + def test_hostname_in_public_dns_with_custom_timeout(self, mock_gethostbyname, mock_set_timeout, mock_get_timeout): + """Test that the function uses the provided timeout value.""" + # Mock socket.gethostbyname to return a public IP address + mock_gethostbyname.return_value = '8.8.8.8' # Google's public DNS server + mock_get_timeout.return_value = None # Default timeout + + # Call the function with a custom timeout + result = is_hostname_in_public_dns('example.com', timeout=5.0) + + # Verify that setdefaulttimeout was called with the correct value + mock_set_timeout.assert_any_call(5.0) + + # Verify the result + self.assertTrue(result) + mock_gethostbyname.assert_called_once_with('example.com') + + @patch('src.ssl_manager.is_hostname_in_public_dns') + @patch('subprocess.run') + def test_generate_letsencrypt_cert_checks_dns(self, mock_run, mock_is_hostname_in_public_dns): + """Test that generate_letsencrypt_cert checks if the hostname is in public DNS.""" + # Mock is_hostname_in_public_dns to return False + mock_is_hostname_in_public_dns.return_value = False + + # Call the method and expect a ValueError + with self.assertRaises(ValueError) as context: + self.ssl_manager.generate_letsencrypt_cert( + common_name="example.com", + email="test@example.com" + ) + + # Verify the error message + self.assertIn("not in public DNS", str(context.exception)) + + # Verify that is_hostname_in_public_dns was called with the correct parameters + mock_is_hostname_in_public_dns.assert_called_once_with( + "example.com", + timeout=self.ssl_manager.connection_timeout + ) + + # Verify that subprocess.run was not called (because we failed at DNS validation) + mock_run.assert_not_called() + + +if __name__ == '__main__': + unittest.main() \ No newline at end of file diff --git a/tests/test_letsencrypt.py b/tests/test_letsencrypt.py new file mode 100644 index 0000000..9baae9b --- /dev/null +++ b/tests/test_letsencrypt.py @@ -0,0 +1,182 @@ +#!/usr/bin/env python3 +""" +Tests for the Let's Encrypt functionality of the SSL Manager. + +This module contains tests for generating certificates using Let's Encrypt. +""" + +import os +import sys +import tempfile +import unittest +import subprocess +from unittest.mock import patch, MagicMock + +# Add the src directory to the Python path +sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '../src'))) + +from ssl_manager import SSLManager + + +class TestLetsEncrypt(unittest.TestCase): + """Test cases for Let's Encrypt certificate generation.""" + + def setUp(self): + """Set up test fixtures.""" + # Create a temporary directory for test files + self.temp_dir = tempfile.TemporaryDirectory() + + # Sample config for testing + self.test_config = { + "cert_dir": self.temp_dir.name, + "default_port": 8443, + "connection_timeout": 5.0, + "default_validity_days": 730, + "key_size": 4096, + "letsencrypt": { + "email": "test@example.com", + "validation_method": "standalone", + "use_staging": True, + "agree_tos": True + } + } + + # Create a temporary config file + self.config_path = os.path.join(self.temp_dir.name, "test_config.json") + with open(self.config_path, 'w') as f: + import json + json.dump(self.test_config, f) + + # Create an SSLManager with the test config + self.ssl_manager = SSLManager(config_path=self.config_path) + + # Create directories for certbot + self.config_dir = os.path.join(self.temp_dir.name, '.config') + self.work_dir = os.path.join(self.temp_dir.name, '.work') + self.logs_dir = os.path.join(self.temp_dir.name, '.logs') + self.live_dir = os.path.join(self.config_dir, 'live', 'test.example.com') + + # Create directories + os.makedirs(self.live_dir, exist_ok=True) + + # Create dummy certificate and key files + self.fullchain_path = os.path.join(self.live_dir, 'fullchain.pem') + self.privkey_path = os.path.join(self.live_dir, 'privkey.pem') + + with open(self.fullchain_path, 'w') as f: + f.write("-----BEGIN CERTIFICATE-----\nDummy Certificate\n-----END CERTIFICATE-----") + + with open(self.privkey_path, 'w') as f: + f.write("-----BEGIN PRIVATE KEY-----\nDummy Private Key\n-----END PRIVATE KEY-----") + + def tearDown(self): + """Tear down test fixtures.""" + # Clean up the temporary directory + self.temp_dir.cleanup() + + @patch('subprocess.run') + def test_generate_letsencrypt_cert(self, mock_run): + """Test generating a Let's Encrypt certificate.""" + # Mock the subprocess.run call + mock_run.return_value = MagicMock(returncode=0, stdout="Certificate issued successfully") + + # Call the method + cert_path, key_path = self.ssl_manager.generate_letsencrypt_cert( + common_name="test.example.com", + email="test@example.com", + validation_method="standalone", + use_staging=True, + agree_tos=True + ) + + # Verify that subprocess.run was called with the correct arguments + mock_run.assert_called_once() + args, kwargs = mock_run.call_args + + # Verify the command includes the expected arguments + cmd = args[0] + self.assertEqual(cmd[0], 'certbot') + self.assertEqual(cmd[1], 'certonly') + self.assertIn('--standalone', cmd) + self.assertIn('-d', cmd) + self.assertIn('test.example.com', cmd) + self.assertIn('-m', cmd) + self.assertIn('test@example.com', cmd) + self.assertIn('--test-cert', cmd) + self.assertIn('--agree-tos', cmd) + self.assertIn('-n', cmd) + + # Verify the paths + expected_cert_path = os.path.join(self.temp_dir.name, 'test.example.com.crt') + expected_key_path = os.path.join(self.temp_dir.name, 'test.example.com.key') + self.assertEqual(cert_path, expected_cert_path) + self.assertEqual(key_path, expected_key_path) + + # Verify the certificate and key files were created + self.assertTrue(os.path.isfile(cert_path)) + self.assertTrue(os.path.isfile(key_path)) + + # Verify the content of the certificate and key files + with open(cert_path, 'r') as f: + cert_content = f.read() + with open(key_path, 'r') as f: + key_content = f.read() + + self.assertIn("Dummy Certificate", cert_content) + self.assertIn("Dummy Private Key", key_content) + + @patch('subprocess.run') + def test_generate_letsencrypt_cert_with_defaults(self, mock_run): + """Test generating a Let's Encrypt certificate with default values from config.""" + # Mock the subprocess.run call + mock_run.return_value = MagicMock(returncode=0, stdout="Certificate issued successfully") + + # Call the method with minimal arguments + cert_path, key_path = self.ssl_manager.generate_letsencrypt_cert( + common_name="test.example.com" + ) + + # Verify that subprocess.run was called with the correct arguments + mock_run.assert_called_once() + args, kwargs = mock_run.call_args + + # Verify the command includes the expected arguments + cmd = args[0] + self.assertEqual(cmd[0], 'certbot') + self.assertEqual(cmd[1], 'certonly') + self.assertIn('--standalone', cmd) # Default from config + self.assertIn('-d', cmd) + self.assertIn('test.example.com', cmd) + self.assertIn('-m', cmd) + self.assertIn('test@example.com', cmd) # Default from config + self.assertIn('--test-cert', cmd) # Default from config + self.assertIn('--agree-tos', cmd) # Default from config + self.assertIn('-n', cmd) + + # Verify the paths + expected_cert_path = os.path.join(self.temp_dir.name, 'test.example.com.crt') + expected_key_path = os.path.join(self.temp_dir.name, 'test.example.com.key') + self.assertEqual(cert_path, expected_cert_path) + self.assertEqual(key_path, expected_key_path) + + @patch('subprocess.run') + def test_generate_letsencrypt_cert_error(self, mock_run): + """Test error handling when generating a Let's Encrypt certificate.""" + # Mock the subprocess.run call to raise an exception + mock_run.side_effect = subprocess.CalledProcessError( + returncode=1, + cmd=['certbot', 'certonly'], + output="An error occurred", + stderr="Certificate issuance failed" + ) + + # Call the method and expect an exception + with self.assertRaises(subprocess.CalledProcessError): + self.ssl_manager.generate_letsencrypt_cert( + common_name="test.example.com", + email="test@example.com" + ) + + +if __name__ == '__main__': + unittest.main() \ No newline at end of file diff --git a/tests/test_ssl_manager.py b/tests/test_ssl_manager.py index 61472d9..04b5f19 100644 --- a/tests/test_ssl_manager.py +++ b/tests/test_ssl_manager.py @@ -92,6 +92,11 @@ class TestSSLManager(unittest.TestCase): @patch('ssl_manager.subprocess.run') def test_validate_cert_chain_valid(self, mock_run): """Test validating a valid certificate chain.""" + # Create a temporary certificate file + cert_path = os.path.join(self.temp_dir.name, 'test.crt') + with open(cert_path, 'w') as f: + f.write("-----BEGIN CERTIFICATE-----\nDummy Certificate\n-----END CERTIFICATE-----") + # Mock the subprocess call mock_run.return_value = MagicMock( returncode=0, @@ -99,7 +104,7 @@ class TestSSLManager(unittest.TestCase): ) # Call the method - result = self.ssl_manager.validate_cert_chain('test.crt') + result = self.ssl_manager.validate_cert_chain(cert_path) # Verify the result self.assertTrue(result) @@ -108,14 +113,19 @@ class TestSSLManager(unittest.TestCase): @patch('ssl_manager.subprocess.run') def test_validate_cert_chain_invalid(self, mock_run): """Test validating an invalid certificate chain.""" + # Create a temporary certificate file + cert_path = os.path.join(self.temp_dir.name, 'invalid.crt') + with open(cert_path, 'w') as f: + f.write("-----BEGIN CERTIFICATE-----\nInvalid Certificate\n-----END CERTIFICATE-----") + # Mock the subprocess call mock_run.return_value = MagicMock( returncode=1, - stdout="test.crt: C = US, O = Example, CN = example.com\nerror 2 at 1 depth lookup: unable to get issuer certificate" + stdout="invalid.crt: C = US, O = Example, CN = example.com\nerror 2 at 1 depth lookup: unable to get issuer certificate" ) # Call the method - result = self.ssl_manager.validate_cert_chain('test.crt') + result = self.ssl_manager.validate_cert_chain(cert_path) # Verify the result self.assertFalse(result)