#!/usr/bin/env python3

import ctypes
from ctypes import Structure, POINTER, c_char_p, c_int, c_void_p, c_uint, byref, cast
import sys
import argparse
import os
import sysconfig

# Dynamically find the systemd shared library
def find_systemd_library():
    """Locate the systemd shared library for this interpreter's architecture.

    The search directory is ``/usr/lib/<multiarch>/systemd``, where the
    multiarch tuple is read from Python's build configuration. Version 257 is
    tried before 252 and the first file that exists wins.

    Returns:
        A ``(path, version)`` tuple; ``version`` is ``'257'`` or ``'252'``.

    Raises:
        OSError: if the multiarch tuple is unavailable, or if neither library
            version exists in the search directory.
    """
    arch_tuple = sysconfig.get_config_var('MULTIARCH')
    if not arch_tuple:
        raise OSError("Unable to determine system multiarch tuple. This may indicate an incomplete Python installation.")

    systemd_lib_dir = f'/usr/lib/{arch_tuple}/systemd'

    # Lazily build (path, version) candidates, newest first, so the filesystem
    # is only probed until the first hit.
    candidates = (
        (f'{systemd_lib_dir}/libsystemd-shared-{release}.so', release)
        for release in ('257', '252')
    )
    found = next((pair for pair in candidates if os.path.exists(pair[0])), None)
    if found is None:
        raise OSError(f"No compatible systemd shared library found in {systemd_lib_dir}. Supported versions: 252, 257")
    return found

# Load the systemd shared library
try:
    # Resolve the library path first, then hand it to the dynamic loader
    lib_path, systemd_version = find_systemd_library()
    lib = ctypes.CDLL(lib_path)
    # Diagnostics are opt-in: any non-empty DEBUG value in the environment enables them
    if os.environ.get('DEBUG'):
        print(f"Debug: Using systemd library version {systemd_version} from {lib_path}", file=sys.stderr)
except OSError as load_error:
    # Covers both "library not found" and loader failures; nothing below can
    # work without the library, so report and stop.
    print(f"Error: Failed to load systemd library: {load_error}", file=sys.stderr)
    sys.exit(1)

# Bit values for the ConfigParseFlags enum
CONFIG_PARSE_RELAXED = 0x01  # bit 0
CONFIG_PARSE_WARN = 0x02     # bit 1

# Bit values for the ConfFilesFlags enum
CONF_FILES_EXECUTABLE = 0x01     # bit 0
CONF_FILES_REGULAR = 0x02        # bit 1
CONF_FILES_DIRECTORY = 0x04      # bit 2
CONF_FILES_BASENAME = 0x08       # bit 3
CONF_FILES_FILTER_MASKED = 0x10  # bit 4

# Define the ConfigTableItem struct
class ConfigTableItem(Structure):
    """ctypes layout of one entry in the lookup table passed to the config parser.

    An array of these, ending with an all-NULL entry, is built by
    create_config_table().
    """

    _fields_ = [
        # Section name the entry applies to
        ("section", c_char_p),
        # Key name inside that section (systemd calls it 'lvalue')
        ("lvalue", c_char_p),
        # Parser callback, stored as a raw function pointer
        ("parse", c_void_p),
        # Always set to 0 by this script
        ("ltype", c_int),
        # Address of the storage the parser callback writes to
        ("data", c_void_p),
    ]

# Define function signatures based on systemd version
def setup_function_signatures():
    """Declare ctypes argument and return types for the library calls used here.

    strv_free, strv_split_nulstr and config_parse_string get the same
    prototypes regardless of version. When version 257 was detected, the
    prototype of config_parse_standard_file_with_dropins_full is declared as
    well; for any other detected version the work is delegated to
    setup_legacy_api().

    Raises:
        OSError: if version 257 was detected but the library does not export
            config_parse_standard_file_with_dropins_full.
    """
    # --- Prototypes shared by both supported versions ---
    strv_free = lib.strv_free
    strv_free.argtypes = [POINTER(c_char_p)]
    strv_free.restype = POINTER(c_char_p)

    strv_split_nulstr = lib.strv_split_nulstr
    strv_split_nulstr.argtypes = [c_char_p]
    strv_split_nulstr.restype = POINTER(c_char_p)

    parse_string = lib.config_parse_string
    parse_string.argtypes = [
        c_char_p,  # unit
        c_char_p,  # filename
        c_uint,    # line
        c_char_p,  # section
        c_uint,    # section_line
        c_char_p,  # lvalue
        c_int,     # ltype
        c_char_p,  # rvalue
        c_void_p,  # data
        c_void_p,  # userdata
    ]
    parse_string.restype = c_int

    if systemd_version != '257':
        # systemd 252 goes through the legacy entry points
        setup_legacy_api()
        return

    # systemd 257: the new entry point is mandatory, there is no fallback
    if not hasattr(lib, 'config_parse_standard_file_with_dropins_full'):
        raise OSError(f"systemd version 257 detected but config_parse_standard_file_with_dropins_full function not found. "
                     f"This indicates an incompatible or corrupted systemd installation.")

    parse_with_dropins = lib.config_parse_standard_file_with_dropins_full
    parse_with_dropins.argtypes = [
        c_char_p,  # root
        c_char_p,  # main_file
        c_char_p,  # sections (nulstr format)
        c_void_p,  # lookup function pointer (ConfigItemLookup)
        c_void_p,  # table (const void *)
        c_int,     # flags (ConfigParseFlags)
        c_void_p,  # userdata
        c_void_p,  # ret_stats_by_path (Hashmap **)
        c_void_p,  # ret_dropin_files (char ***)
    ]
    parse_with_dropins.restype = c_int

def setup_legacy_api():
    """Declare ctypes prototypes for the entry points used with systemd 252.

    Covers config_parse_many and conf_files_list_strv; both are declared to
    return a C int.
    """
    string_array = POINTER(c_char_p)  # char ** as seen from ctypes

    parse_many = lib.config_parse_many
    parse_many.argtypes = [
        string_array,              # conf_files (string array)
        string_array,              # conf_file_dirs (string array)
        c_char_p,                  # dropin_dirname
        c_char_p,                  # sections (nulstr format)
        c_void_p,                  # lookup function pointer
        POINTER(ConfigTableItem),  # table
        c_int,                     # flags
        c_void_p,                  # userdata
        c_void_p,                  # ret_stats_by_path
        c_void_p,                  # ret_drop_in_files
    ]
    parse_many.restype = c_int

    list_strv = lib.conf_files_list_strv
    list_strv.argtypes = [
        POINTER(string_array),  # ret (char***)
        c_char_p,               # suffix
        c_char_p,               # root
        c_int,                  # flags
        string_array,           # dirs
    ]
    list_strv.restype = c_int

# Set up function signatures.
# setup_function_signatures() raises OSError when the 257 entry point is
# missing, and ctypes raises AttributeError when any other symbol cannot be
# resolved. Report either one the same way as a library load failure instead
# of letting a traceback escape.
try:
    setup_function_signatures()
except (OSError, AttributeError) as e:
    print(f"Error: Failed to set up systemd library functions: {e}", file=sys.stderr)
    sys.exit(1)

# Get function pointers.
# config_parse_string was already resolved during signature setup; the table
# lookup helper is optional here and its absence is reported later by main().
config_parse_string_func = lib.config_parse_string
lookup_func = getattr(lib, 'config_item_table_lookup', None)

def python_list_to_strv(py_list):
    """Turn a Python list of strings into a C string array.

    The strings are packed into one NUL-separated buffer and handed to the
    library's strv_split_nulstr, whose result is returned as-is.

    Returns:
        The pointer returned by strv_split_nulstr, or None when ``py_list``
        is empty or None.
    """
    if py_list:
        # Joining with a trailing empty item leaves a NUL after every entry
        packed = '\0'.join([*py_list, ''])
        return lib.strv_split_nulstr(packed.encode('utf-8'))
    return None

def create_config_table(parse_func, queries):
    """Build the lookup table plus one output slot per query.

    Args:
        parse_func: library function stored as the parser callback of every entry.
        queries: sequence of ``(section, key)`` string pairs.

    Returns:
        ``(table, parsed_values)``: a ConfigTableItem array holding one entry
        per query followed by an all-NULL terminator, and the list of
        ``c_char_p`` slots the entries' ``data`` fields point at, in query order.
    """
    query_count = len(queries)
    table = (ConfigTableItem * (query_count + 1))()
    parsed_values = []

    for index, (section, key) in enumerate(queries):
        # The slot must outlive this function; returning it in parsed_values
        # keeps it referenced for the caller.
        slot = c_char_p()
        slot_address = ctypes.pointer(slot)
        parsed_values.append(slot)

        entry = table[index]
        entry.section = section.encode('utf-8')
        entry.lvalue = key.encode('utf-8')  # systemd uses 'lvalue' internally
        entry.parse = cast(parse_func, c_void_p)
        entry.ltype = 0
        entry.data = cast(slot_address, c_void_p)

    # Terminator entry: every field NULL/0
    terminator = table[query_count]
    terminator.section = None
    terminator.lvalue = None
    terminator.parse = None
    terminator.ltype = 0
    terminator.data = None

    return table, parsed_values

def parse_systemd_config_legacy(main_file, queries):
    """Resolve the queries through config_parse_many (systemd 252 path).

    Args:
        main_file: configuration file name relative to each search root.
        queries: sequence of ``(section, key)`` string pairs.

    Returns:
        A list aligned with ``queries`` holding each parsed value as ``str``,
        or None where nothing was stored. If the library call returns a
        non-zero status, or any exception occurs, every entry is None.
    """
    # Standard systemd configuration paths
    search_roots = ["/etc", "/run", "/usr/local/lib", "/usr/lib"]
    main_candidates = [f"{root}/{main_file}" for root in search_roots]
    dropin_dirname = f"{main_file}.d"

    # Each section name needs to appear only once in the nulstr
    unique_sections = list({section for section, _ in queries})
    sections_nulstr = '\0'.join(unique_sections) + '\0'

    config_table, parsed_values = create_config_table(config_parse_string_func, queries)

    conf_files_strv = python_list_to_strv(main_candidates)
    conf_file_dirs_strv = python_list_to_strv(search_roots)

    try:
        dropin_arg = dropin_dirname.encode('utf-8')
        sections_arg = sections_nulstr.encode('utf-8')
        lookup_arg = cast(lookup_func, c_void_p) if lookup_func else None

        status = lib.config_parse_many(
            conf_files_strv,
            conf_file_dirs_strv,
            dropin_arg,
            sections_arg,
            lookup_arg,
            config_table,
            CONFIG_PARSE_RELAXED,
            None,  # userdata
            None,  # ret_stats_by_path
            None,  # ret_drop_in_files
        )

        # Decode whatever the parser callback stored, in query order
        values = [
            slot.value.decode('utf-8') if slot.value else None
            for slot in parsed_values
        ]

        if status != 0:
            return [None] * len(queries)
        return values

    except Exception as e:
        print(f"Error parsing config: {e}", file=sys.stderr)
        return [None] * len(queries)
    finally:
        # Release the string vectors allocated by strv_split_nulstr
        for strv in (conf_files_strv, conf_file_dirs_strv):
            if strv:
                lib.strv_free(strv)

def parse_systemd_config_new(main_file, queries):
    """Resolve the queries through config_parse_standard_file_with_dropins_full (systemd 257 path).

    Args:
        main_file: configuration file name passed to the library as ``main_file``.
        queries: sequence of ``(section, key)`` string pairs.

    Returns:
        A list aligned with ``queries`` holding each parsed value as ``str``,
        or None where nothing was stored. If the library call returns a
        non-zero status, or any exception occurs, every entry is None.
    """
    # Each section name needs to appear only once in the nulstr
    unique_sections = list({section for section, _ in queries})
    sections_nulstr = '\0'.join(unique_sections) + '\0'

    config_table, parsed_values = create_config_table(config_parse_string_func, queries)

    try:
        main_file_arg = main_file.encode('utf-8')
        sections_arg = sections_nulstr.encode('utf-8')
        lookup_arg = cast(lookup_func, c_void_p) if lookup_func else None
        table_arg = cast(config_table, c_void_p)  # passed as const void *

        status = lib.config_parse_standard_file_with_dropins_full(
            None,                  # root (NULL for standard paths)
            main_file_arg,         # main_file
            sections_arg,          # sections
            lookup_arg,            # lookup
            table_arg,             # table
            CONFIG_PARSE_RELAXED,  # flags
            None,                  # userdata
            None,                  # ret_stats_by_path
            None,                  # ret_dropin_files
        )

        # Decode whatever the parser callback stored, in query order
        values = [
            slot.value.decode('utf-8') if slot.value else None
            for slot in parsed_values
        ]

        if status != 0:
            return [None] * len(queries)
        return values

    except Exception as e:
        print(f"Error parsing config: {e}", file=sys.stderr)
        return [None] * len(queries)

def parse_systemd_config(main_file, queries):
    """Dispatch to the parser matching the detected systemd version.

    Version '257' uses parse_systemd_config_new; anything else uses
    parse_systemd_config_legacy. The chosen parser's result is returned unchanged.
    """
    backend = parse_systemd_config_new if systemd_version == '257' else parse_systemd_config_legacy
    return backend(main_file, queries)

def shell_escape(value):
    """Quote ``value`` so it can be used as the right-hand side of a shell assignment.

    Empty or None input yields ``""``. Anything else is wrapped in single
    quotes, with each embedded single quote rewritten as ``'"'"'`` (close the
    quote, emit a double-quoted single quote, reopen the quote).
    """
    if value:
        # Splitting on the quote and re-joining with the escape sequence
        # replaces every occurrence.
        body = "'\"'\"'".join(value.split("'"))
        return "'" + body + "'"
    return '""'

def parse_config_key(key):
    """Split a ``section::key`` specifier at its first ``::``.

    Returns:
        A ``(section, key)`` tuple; later ``::`` occurrences stay in the key part.

    Exits the process with status 1, after printing an error to stderr, when
    the specifier contains no ``::``.
    """
    section, separator, name = key.partition('::')
    if not separator:
        print(f"Error: Invalid config key format '{key}'. Expected 'section::key'", file=sys.stderr)
        sys.exit(1)
    return section, name

def main():
    """Command-line entry point.

    Parses the arguments, resolves every ``shell_variable config_key`` pair
    through parse_systemd_config() in a single call, and prints one
    ``VAR='value'`` line per key that yielded a value. Exits with status 1 on
    an odd number of pair arguments, a malformed config key, or missing
    library functions.
    """
    parser = argparse.ArgumentParser(
        prog='rpi-systemd-config',
        description='SystemD Configuration Query program',
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  # Single configuration entry
  eval "$(rpi-systemd-config systemd/system.conf CONF_VAL Manager::DefaultTimeoutStopSec)"
  echo $CONF_VAL

  # Multiple configuration entries at once
  eval "$(rpi-systemd-config systemd/system.conf TIMEOUT Manager::DefaultTimeoutStopSec LIMIT Manager::DefaultLimitNOFILE)"
  echo "Timeout: $TIMEOUT, Limit: $LIMIT"
        """
    )
    parser.add_argument(
        'name',
        help='Configuration file name relative to prefix (e.g., systemd/system.conf)',
    )
    parser.add_argument(
        'pairs',
        nargs='*',
        help='Pairs of shell_variable config_key',
    )
    parser.add_argument(
        '-v', '--version',
        action='version',
        version=f'rpi-systemd-config 1.1 (systemd {systemd_version})',
    )

    args = parser.parse_args()
    pairs = args.pairs

    # An odd count means some shell variable has no config key to go with it
    if len(pairs) % 2:
        print("Error: Arguments must be pairs of (shell_variable config_key)", file=sys.stderr)
        parser.print_help()
        sys.exit(1)

    # Both library entry points are needed before any parsing can happen
    if not lookup_func or not config_parse_string_func:
        print("Error: Required systemd functions not found", file=sys.stderr)
        sys.exit(1)

    # Even positions are shell variable names, odd positions are config keys
    var_names = []
    queries = []
    for var_name, config_key in zip(pairs[0::2], pairs[1::2]):
        section, key = parse_config_key(config_key)
        queries.append((section, key))
        var_names.append(var_name)

    # Parse all queries at once
    results = parse_systemd_config(args.name, queries)

    # NOTE(review): var_name is printed verbatim and the output is meant for
    # eval, so callers must pass valid shell identifiers.
    for var_name, value in zip(var_names, results):
        if value is None:
            # Nothing found: emit no assignment at all (like apt-config)
            continue
        print(f"{var_name}={shell_escape(value)}")

if __name__ == "__main__":
    main()