"""ctypes bindings for the usvfs (user-space virtual filesystem) DLL.

The module exposes thin Python wrappers around the C entry points of the
usvfs library: parameter handling (`USVFSParameters`), description of the
virtual links to create (`VFS`), hooked processes (`Process`) and the main
entry point (`USVFS`).
"""

from __future__ import annotations

import ctypes
import ctypes.util
import dataclasses
from datetime import timedelta
from enum import Enum
from pathlib import Path
from typing import Final, Literal, overload

# Size of the crash-dump path buffer inside the native parameters structure.
# NOTE(review): 260 (Windows MAX_PATH) matches `char crashDumpsPath[260]` in
# the usvfs parameters header -- confirm against the usvfs build being loaded.
_CRASH_DUMPS_PATH_SIZE = 260


class _C_USVFSParameters(ctypes.Structure):
    """ctypes mirror of the native usvfs parameters structure."""

    _fields_ = (
        ("instanceName", ctypes.c_char * 65),
        ("currentSHMName", ctypes.c_char * 65),
        ("currentInverseSHMName", ctypes.c_char * 65),
        ("debugMode", ctypes.c_bool),
        ("logLevel", ctypes.c_uint8),
        ("crashDumpsType", ctypes.c_uint8),
        # A path needs a character array: a lone `c_char` can only ever expose
        # the first byte of the stored path.
        ("crashDumpsPath", ctypes.c_char * _CRASH_DUMPS_PATH_SIZE),
    )


class _C_STARTUPINFOW(ctypes.Structure):
    """ctypes mirror of the Win32 STARTUPINFOW structure."""

    _fields_ = (
        ("cb", ctypes.c_int32),
        ("lpReserved", ctypes.c_wchar_p),
        ("lpDesktop", ctypes.c_wchar_p),
        ("lpTitle", ctypes.c_wchar_p),
        ("dwX", ctypes.c_int32),
        ("dwY", ctypes.c_int32),
        ("dwXSize", ctypes.c_int32),
        ("dwYSize", ctypes.c_int32),
        ("dwXCountChars", ctypes.c_int32),
        ("dwYCountChars", ctypes.c_int32),
        ("dwFillAttribute", ctypes.c_int32),
        ("dwFlags", ctypes.c_int32),
        ("wShowWindow", ctypes.c_int16),
        ("cbReserved2", ctypes.c_int16),
        ("lpReserved2", ctypes.POINTER(ctypes.c_int8)),
        ("hStdInput", ctypes.c_void_p),
        ("hStdOutput", ctypes.c_void_p),
        ("hStdError", ctypes.c_void_p),
    )


class _C_PROCESS_INFORMATION(ctypes.Structure):
    """ctypes mirror of the Win32 PROCESS_INFORMATION structure."""

    _fields_ = (
        ("hProcess", ctypes.c_void_p),
        ("hThread", ctypes.c_void_p),
        ("dwProcessId", ctypes.c_int32),
        ("dwThreadId", ctypes.c_int32),
    )


class LogLevel(Enum):
    """Log levels; the integer value is what is handed to usvfsSetLogLevel."""

    Debug = 0
    Info = 1
    Warning = 2
    Error = 3


class CrashDumpType(Enum):
    """Crash dump kinds; the value is what is handed to usvfsSetCrashDumpType."""

    Off = 0
    Mini = 1
    Data = 2
    Full = 3


class USVFSParameters:
    """Owner of a native parameters object created by the usvfs library.

    Values are read directly from the native structure and written through
    the dedicated `usvfsSet*` functions of the library.
    """

    def __init__(self, usvfs: ctypes.WinDLL):
        """Create a native parameters object using the loaded `usvfs` DLL."""
        self._usvfs = usvfs
        self._c_params = usvfs.usvfsCreateParameters()

    def __del__(self) -> None:
        """Release the native parameters object, if it was ever created."""
        # __del__ also runs when __init__ raised before `_c_params` was set.
        c_params = getattr(self, "_c_params", None)
        if c_params is not None:
            self._usvfs.usvfsFreeParameters(c_params)

    @property
    def instance_name(self) -> str:
        """Instance name stored in the native structure."""
        return self._c_params.contents.instanceName.decode()

    @instance_name.setter
    def instance_name(self, value: str) -> None:
        self._usvfs.usvfsSetInstanceName(self._c_params, value.encode())

    @property
    def current_shm_name(self) -> str:
        """Value of the `currentSHMName` field (read-only)."""
        return self._c_params.contents.currentSHMName.decode()

    @property
    def current_inverse_shm_name(self) -> str:
        """Value of the `currentInverseSHMName` field (read-only)."""
        return self._c_params.contents.currentInverseSHMName.decode()

    @property
    def debug_mode(self) -> bool:
        """Value of the `debugMode` field."""
        return self._c_params.contents.debugMode

    @debug_mode.setter
    def debug_mode(self, debug: bool) -> None:
        self._usvfs.usvfsSetDebugMode(self._c_params, debug)

    @property
    def log_level(self) -> LogLevel:
        """Value of the `logLevel` field, as a `LogLevel`."""
        return LogLevel(self._c_params.contents.logLevel)

    @log_level.setter
    def log_level(self, level: LogLevel) -> None:
        self._usvfs.usvfsSetLogLevel(self._c_params, level.value)

    @property
    def crash_dumps_type(self) -> CrashDumpType:
        """Value of the `crashDumpsType` field, as a `CrashDumpType`."""
        return CrashDumpType(self._c_params.contents.crashDumpsType)

    @crash_dumps_type.setter
    def crash_dumps_type(self, dump_type: CrashDumpType) -> None:
        self._usvfs.usvfsSetCrashDumpType(self._c_params, dump_type.value)

    @property
    def crash_dumps_path(self) -> Path:
        """Value of the `crashDumpsPath` field, as a `Path`."""
        return Path(self._c_params.contents.crashDumpsPath.decode())

    @crash_dumps_path.setter
    def crash_dumps_path(self, path: Path | str) -> None:
        self._usvfs.usvfsSetCrashDumpPath(self._c_params, str(path).encode())


_LINKFLAG_FAILIFEXISTS = 0x00000001
_LINKFLAG_MONITORCHANGES = 0x00000002
_LINKFLAG_CREATETARGET = 0x00000004
_LINKFLAG_RECURSIVE = 0x00000008
_LINKFLAG_FAILIFSKIPPED = 0x00000010


def _flag(enabled: bool, bit: int) -> int:
    """Return `bit` when `enabled` is truthy, otherwise the integer 0."""
    return bit if enabled else 0


class VFS:
    """In-memory description of the links of a virtual filesystem.

    Entries are only recorded here; they are handed to the library by
    `USVFS.connect_virtual_filesystem`.
    """

    @dataclasses.dataclass(frozen=True)
    class _Entry:
        """One recorded link: its kind, both absolute paths and its flags."""

        mode: Literal["file", "directory"]
        source: str
        target: str
        flags: int

    def __init__(self) -> None:
        """Create an empty filesystem description."""
        self._entries: list[VFS._Entry] = []

    def link_file(
        self,
        source: Path | str,
        destination: Path | str,
        *,
        fail_if_exists: bool = False,
        fail_if_skipped: bool = False,
    ) -> None:
        """Record a file link from `source` to `destination`.

        Both paths are made absolute. Each keyword argument sets the matching
        `_LINKFLAG_*` bit in the recorded flags.
        """
        self._entries.append(
            VFS._Entry(
                mode="file",
                source=str(Path(source).absolute()),
                target=str(Path(destination).absolute()),
                flags=_flag(fail_if_exists, _LINKFLAG_FAILIFEXISTS)
                | _flag(fail_if_skipped, _LINKFLAG_FAILIFSKIPPED),
            )
        )

    def link_directory(
        self,
        source: Path | str,
        destination: Path | str,
        *,
        fail_if_exists: bool = False,
        monitor_changes: bool = False,
        recursive: bool = False,
        create_target: bool = False,
        fail_if_skipped: bool = False,
    ) -> None:
        """Record a directory link from `source` to `destination`.

        Both paths are made absolute. Each keyword argument sets the matching
        `_LINKFLAG_*` bit in the recorded flags.
        """
        self._entries.append(
            VFS._Entry(
                mode="directory",
                source=str(Path(source).absolute()),
                target=str(Path(destination).absolute()),
                flags=_flag(fail_if_exists, _LINKFLAG_FAILIFEXISTS)
                | _flag(monitor_changes, _LINKFLAG_MONITORCHANGES)
                | _flag(recursive, _LINKFLAG_RECURSIVE)
                | _flag(create_target, _LINKFLAG_CREATETARGET)
                | _flag(fail_if_skipped, _LINKFLAG_FAILIFSKIPPED),
            )
        )


class Process:
    """A process started through `USVFS.run_hooked_process`.

    Holds the STARTUPINFOW / PROCESS_INFORMATION structures that are passed
    to the process-creation call and exposes the resulting handles and ids.
    """

    def __init__(self, command: str, cwd: Path | None):
        """Store the command line and working directory, prepare the structs."""
        self.command: Final = command
        self.cwd: Final = cwd

        self.si = _C_STARTUPINFOW()
        # `cb` must hold the size of the structure itself.
        self.si.cb = ctypes.sizeof(_C_STARTUPINFOW)
        self.pi = _C_PROCESS_INFORMATION()

    @property
    def process_handle(self) -> int:
        """Value of `hProcess` in the process information structure."""
        return self.pi.hProcess

    @property
    def thread_handle(self) -> int:
        """Value of `hThread` in the process information structure."""
        return self.pi.hThread

    @property
    def process_id(self) -> int:
        """Value of `dwProcessId` in the process information structure."""
        return self.pi.dwProcessId

    @property
    def thread_id(self) -> int:
        """Value of `dwThreadId` in the process information structure."""
        return self.pi.dwThreadId


class USVFS:
    """Wrapper around a loaded usvfs DLL.

    Only one instance may be registered at a time. The registered instance is
    stored in `USVFS.instance`, which is a strong reference: the object stays
    alive until that attribute is reset.
    """

    instance: USVFS | None = None

    def __init__(
        self, path: Path | str, logging: bool | Literal["console"] = True
    ) -> None:
        """Load the DLL at `path` and register this object as the instance.

        Args:
            path: Location of the usvfs DLL.
            logging: Falsy to skip `usvfsInitLogging`; otherwise the function
                is called with `True` only when the value is "console".

        Raises:
            AssertionError: If another instance is already registered.
        """
        # Explicit raise instead of `assert` so the check survives `python -O`.
        if USVFS.instance is not None:
            raise AssertionError("only one USVFS instance can exist at a time")

        self._usvfs = ctypes.windll.LoadLibrary(str(path))

        self._usvfs.usvfsCreateParameters.restype = ctypes.POINTER(_C_USVFSParameters)
        self._usvfs.usvfsVersionString.restype = ctypes.c_char_p
        self._usvfs.usvfsGetVFSProcessList2.argtypes = (
            ctypes.POINTER(ctypes.c_size_t),
            ctypes.POINTER(ctypes.POINTER(ctypes.c_int32)),
        )

        self._vfs: VFS | None = None
        self._process_by_id: dict[int, Process] = {}

        if logging:
            self._usvfs.usvfsInitLogging(logging == "console")

        # Register last: a failure above must not leave a half-initialised
        # object registered as the singleton.
        USVFS.instance = self

    def version(self) -> str:
        """Return the string reported by `usvfsVersionString`."""
        return self._usvfs.usvfsVersionString().decode()

    def print_debug_info(self):
        """Call `usvfsPrintDebugInfo`."""
        self._usvfs.usvfsPrintDebugInfo()

    def make_parameters(self) -> USVFSParameters:
        """Create a new parameters object bound to the loaded DLL."""
        return USVFSParameters(self._usvfs)

    def make_virtual_filesystem(self) -> VFS:
        """Create a new, empty filesystem description."""
        return VFS()

    def connect_virtual_filesystem(
        self, parameters: USVFSParameters, filesystem: VFS, disconnect: bool = False
    ) -> None:
        """Connect to a VFS and create every link recorded in `filesystem`.

        Args:
            parameters: Parameters passed to `usvfsConnectVFS`.
            filesystem: Links to create once connected.
            disconnect: When true, disconnect from the current filesystem
                first instead of raising.

        Raises:
            ValueError: If a filesystem is already connected and `disconnect`
                is false.
        """
        if self._vfs is not None and not disconnect:
            raise ValueError("cannot connect to two virtual filesystem at once")

        if disconnect:
            self.disconnect_virtual_filesystem()

        self._usvfs.usvfsConnectVFS(parameters._c_params)  # pyright: ignore[reportPrivateUsage]
        self._vfs = filesystem

        for entry in self._vfs._entries:  # pyright: ignore[reportPrivateUsage]
            func = (
                self._usvfs.usvfsVirtualLinkFile
                if entry.mode == "file"
                else self._usvfs.usvfsVirtualLinkDirectoryStatic
            )
            func(entry.source, entry.target, entry.flags)

    def disconnect_virtual_filesystem(self) -> None:
        """Disconnect from the current filesystem; no-op when not connected."""
        if self._vfs is not None:
            self._usvfs.usvfsDisconnectVFS()
            self._vfs = None

    @overload
    def run_hooked_process(
        self,
        command: str,
        *,
        cwd: Path | None = None,
        wait: Literal[False],
    ) -> Process: ...

    @overload
    def run_hooked_process(
        self,
        command: str,
        cwd: Path | None = None,
        wait: timedelta | Literal["infinite"] = "infinite",
    ) -> int: ...

    def run_hooked_process(
        self,
        command: str,
        cwd: Path | None = None,
        wait: timedelta | Literal["infinite"] | Literal[False] = "infinite",
    ) -> Process | int:
        """Start `command` through `usvfsCreateProcessHooked`.

        Args:
            command: Command line of the process to start.
            cwd: Working directory, or None to pass a NULL pointer.
            wait: `False` to return immediately, otherwise the timeout that
                is forwarded to `wait_for_process`.

        Returns:
            The `Process` when `wait` is `False`, otherwise the value returned
            by `wait_for_process` (the exit code).

        Raises:
            Exception: If the process could not be created, or from
                `wait_for_process`.
        """
        p = Process(command, cwd)

        # CreateProcessW may modify the command line in place, so hand it a
        # writable buffer rather than the storage of an immutable `str`.
        command_buffer = ctypes.create_unicode_buffer(command)

        create_ret = self._usvfs.usvfsCreateProcessHooked(
            None,  # application name
            command_buffer,
            None,  # process attributes
            None,  # thread attributes
            False,
            0,
            None,  # environment
            None if cwd is None else str(cwd),
            ctypes.pointer(p.si),
            ctypes.pointer(p.pi),
        )

        if create_ret == 0:
            raise Exception("failed to create process")

        self._process_by_id[p.process_id] = p

        # Identity test: a zero `timedelta` is falsy but still requests a wait.
        if wait is False:
            return p

        return self.wait_for_process(p, wait)

    def wait_for_process(
        self, process: Process, wait: timedelta | Literal["infinite"] = "infinite"
    ) -> int:
        """Wait for `process`, close its handles and return its exit code.

        Args:
            process: Process to wait for.
            wait: Timeout, or "infinite" to wait without limit.

        Raises:
            Exception: If `WaitForSingleObject` returns a non-zero value or
                the exit code cannot be retrieved. The handles are left open
                and the process stays registered in that case.
        """
        kernel32 = ctypes.windll.kernel32

        wait_i: int = 2**32 - 1  # INFINITE
        if wait != "infinite":
            wait_i = int(wait.total_seconds() * 1000)

        # Wrap handles in c_void_p: plain Python ints are passed as C `int`
        # and masked to fit, which is narrower than a handle on 64-bit.
        process_handle = ctypes.c_void_p(process.process_handle)
        thread_handle = ctypes.c_void_p(process.pi.hThread)

        if kernel32.WaitForSingleObject(process_handle, wait_i) != 0:
            raise Exception("failed to wait for process")

        exit_code = ctypes.c_int32(99)
        if not kernel32.GetExitCodeProcess(process_handle, ctypes.pointer(exit_code)):
            raise Exception("failed to get process exit code")

        kernel32.CloseHandle(process_handle)
        kernel32.CloseHandle(thread_handle)

        # pop() rather than del: the process may not be registered here.
        self._process_by_id.pop(process.process_id, None)

        return exit_code.value

    def get_hooked_processes(self) -> list[Process]:
        """Return the known processes reported by `usvfsGetVFSProcessList2`.

        Only processes started through this object (and still registered) are
        returned; other reported process ids are ignored.

        Raises:
            Exception: If the library call reports a failure.
        """
        size = ctypes.c_size_t(0)
        buffer = ctypes.POINTER(ctypes.c_int32)()
        if (
            self._usvfs.usvfsGetVFSProcessList2(
                ctypes.byref(size), ctypes.byref(buffer)
            )
            == 0
        ):
            raise Exception("failed to get VFS process list")

        if size.value == 0:
            return []

        processes = [
            process
            for i in range(size.value)
            if (process := self._process_by_id.get(buffer[i], None)) is not None
        ]

        # NOTE: `buffer` is never released. Freeing it with
        # `ctypes.cdll.msvcrt.free` was found to crash; presumably the DLL
        # allocates from a different C runtime -- TODO confirm.

        return processes

    def get_log_messages(self) -> list[str]:
        """Collect messages from `usvfsGetLogMessages` until it returns false."""
        size = 2048
        buffer = ctypes.create_string_buffer(size)
        messages: list[str] = []
        while self._usvfs.usvfsGetLogMessages(buffer, size, False):
            messages.append(buffer.value.decode())
        return messages

    def add_blacklisted_executable(self, executable: str) -> None:
        """Pass `executable` to `usvfsBlacklistExecutable`."""
        self._usvfs.usvfsBlacklistExecutable(executable)

    def clear_blacklist_executables(self):
        """Call `usvfsClearExecutableBlacklist`."""
        self._usvfs.usvfsClearExecutableBlacklist()

    def add_force_loaded_library(self, process: str, library: str | Path) -> None:
        """Pass `process` and `library` to `usvfsForceLoadLibrary`."""
        self._usvfs.usvfsForceLoadLibrary(process, str(library))

    def clear_force_loaded_libraries(self):
        """Call `usvfsClearLibraryForceLoads`."""
        self._usvfs.usvfsClearLibraryForceLoads()

    def add_skip_file_suffix(self, suffix: str) -> None:
        """Pass `suffix` to `usvfsAddSkipFileSuffix`."""
        self._usvfs.usvfsAddSkipFileSuffix(suffix)

    def clear_skip_file_suffixes(self):
        """Call `usvfsClearSkipFileSuffixes`."""
        self._usvfs.usvfsClearSkipFileSuffixes()

    def add_skip_directory(self, name: str) -> None:
        """Pass `name` to `usvfsAddSkipDirectory`."""
        self._usvfs.usvfsAddSkipDirectory(name)

    def clear_skip_directory(self):
        """Call `usvfsClearSkipDirectories`."""
        self._usvfs.usvfsClearSkipDirectories()

    def __del__(self) -> None:
        """Drop the DLL reference and unregister this object if registered."""
        # __del__ also runs for objects whose __init__ raised early.
        if hasattr(self, "_usvfs"):
            del self._usvfs
        # Never clear the registration of a different, live instance.
        if USVFS.instance is self:
            USVFS.instance = None