pyusvfs/usvfs.py
2024-08-07 16:10:57 +02:00

411 lines
12 KiB
Python

from __future__ import annotations
import ctypes
import ctypes.util
import dataclasses
from datetime import timedelta
from enum import Enum
from pathlib import Path
from typing import Final, Literal, overload
class _C_USVFSParameters(ctypes.Structure):
    """ctypes view of the usvfs ``usvfsParameters`` structure.

    Only read through the pointer returned by ``usvfsCreateParameters``; all
    writes go through the DLL's ``usvfsSet*`` functions.
    """

    # NOTE(review): the native struct may carry further members after
    # crashDumpsPath (e.g. a process delay). They are not mapped because this
    # layout is only used to read DLL-owned memory — confirm before ever
    # allocating this structure from Python.
    _fields_ = (
        ("instanceName", ctypes.c_char * 65),
        ("currentSHMName", ctypes.c_char * 65),
        ("currentInverseSHMName", ctypes.c_char * 65),
        ("debugMode", ctypes.c_bool),
        ("logLevel", ctypes.c_uint8),
        ("crashDumpsType", ctypes.c_uint8),
        # char[MAX_PATH] in usvfs; a single c_char truncated the path to one byte.
        ("crashDumpsPath", ctypes.c_char * 260),
    )
class _C_STARTUPINFOW(ctypes.Structure):
    """ctypes layout of the Win32 ``STARTUPINFOW`` structure.

    DWORD and WORD members are unsigned and ``lpReserved2`` is an ``LPBYTE``,
    as in the Windows SDK declaration. Sizes and offsets are unchanged from a
    signed declaration, so the ABI layout is identical.
    """

    _fields_ = (
        ("cb", ctypes.c_uint32),  # DWORD
        ("lpReserved", ctypes.c_wchar_p),
        ("lpDesktop", ctypes.c_wchar_p),
        ("lpTitle", ctypes.c_wchar_p),
        ("dwX", ctypes.c_uint32),
        ("dwY", ctypes.c_uint32),
        ("dwXSize", ctypes.c_uint32),
        ("dwYSize", ctypes.c_uint32),
        ("dwXCountChars", ctypes.c_uint32),
        ("dwYCountChars", ctypes.c_uint32),
        ("dwFillAttribute", ctypes.c_uint32),
        ("dwFlags", ctypes.c_uint32),
        ("wShowWindow", ctypes.c_uint16),  # WORD
        ("cbReserved2", ctypes.c_uint16),  # WORD
        ("lpReserved2", ctypes.POINTER(ctypes.c_ubyte)),  # LPBYTE
        ("hStdInput", ctypes.c_void_p),
        ("hStdOutput", ctypes.c_void_p),
        ("hStdError", ctypes.c_void_p),
    )
class _C_PROCESS_INFORMATION(ctypes.Structure):
    """ctypes layout of the Win32 ``PROCESS_INFORMATION`` structure.

    The id members are declared as signed 32-bit integers. This keeps them
    directly comparable with the PIDs read through the ``c_int32`` buffer of
    ``usvfsGetVFSProcessList2``.
    """

    _fields_ = [
        # Handles of the created process and of its primary thread.
        ("hProcess", ctypes.c_void_p),
        ("hThread", ctypes.c_void_p),
        # Numeric identifiers of the same process and thread.
        ("dwProcessId", ctypes.c_int32),
        ("dwThreadId", ctypes.c_int32),
    ]
class LogLevel(Enum):
    """Verbosity passed to ``usvfsSetLogLevel`` and read back from ``logLevel``."""

    # Declaration order fixes the integer values 0..3 sent to the DLL.
    Debug, Info, Warning, Error = range(4)
class CrashDumpType(Enum):
    """Crash-dump mode passed to ``usvfsSetCrashDumpType`` and read from ``crashDumpsType``."""

    # Declaration order fixes the integer values 0..3 sent to the DLL.
    Off, Mini, Data, Full = range(4)
class USVFSParameters:
    """Owning wrapper around a parameter object allocated by the usvfs DLL.

    Getters read fields through the structure pointer returned by
    ``usvfsCreateParameters``. Setters call the DLL's ``usvfsSet*``
    functions. The object is released with ``usvfsFreeParameters`` when
    this wrapper is finalized.
    """

    def __init__(self, usvfs: ctypes.WinDLL):
        """Allocate a new parameter object via ``usvfs.usvfsCreateParameters``.

        Args:
            usvfs: The loaded usvfs library.
        """
        self._usvfs = usvfs
        self._c_params = usvfs.usvfsCreateParameters()

    def __del__(self):
        # Free at most once. Also tolerate a failed __init__ (no _c_params)
        # and a NULL pointer.
        c_params = getattr(self, "_c_params", None)
        if not c_params:
            return
        self._c_params = None
        self._usvfs.usvfsFreeParameters(c_params)

    @property
    def instance_name(self) -> str:
        """Instance name, decoded from the ``instanceName`` field."""
        return self._c_params.contents.instanceName.decode()

    @instance_name.setter
    def instance_name(self, value: str) -> None:
        self._usvfs.usvfsSetInstanceName(self._c_params, value.encode())

    @property
    def current_shm_name(self) -> str:
        """Value of the ``currentSHMName`` field (read-only here)."""
        return self._c_params.contents.currentSHMName.decode()

    @property
    def current_inverse_shm_name(self) -> str:
        """Value of the ``currentInverseSHMName`` field (read-only here)."""
        return self._c_params.contents.currentInverseSHMName.decode()

    @property
    def debug_mode(self) -> bool:
        """Value of the ``debugMode`` field."""
        return self._c_params.contents.debugMode

    @debug_mode.setter
    def debug_mode(self, debug: bool) -> None:
        self._usvfs.usvfsSetDebugMode(self._c_params, debug)

    @property
    def log_level(self) -> LogLevel:
        """Value of the ``logLevel`` field as a ``LogLevel``."""
        return LogLevel(self._c_params.contents.logLevel)

    @log_level.setter
    def log_level(self, level: LogLevel) -> None:
        self._usvfs.usvfsSetLogLevel(self._c_params, level.value)

    @property
    def crash_dumps_type(self) -> CrashDumpType:
        """Value of the ``crashDumpsType`` field as a ``CrashDumpType``."""
        return CrashDumpType(self._c_params.contents.crashDumpsType)

    @crash_dumps_type.setter
    def crash_dumps_type(self, dump_type: CrashDumpType) -> None:
        self._usvfs.usvfsSetCrashDumpType(self._c_params, dump_type.value)

    @property
    def crash_dumps_path(self) -> Path:
        """Value of the ``crashDumpsPath`` field as a ``Path``."""
        return Path(self._c_params.contents.crashDumpsPath.decode())

    @crash_dumps_path.setter
    def crash_dumps_path(self, path: Path | str) -> None:
        self._usvfs.usvfsSetCrashDumpPath(self._c_params, str(path).encode())
# Link-flag bits OR-ed together by VFS.link_file / VFS.link_directory and
# passed as the flags argument of the usvfs link calls.
_LINKFLAG_FAILIFEXISTS = 1 << 0
_LINKFLAG_MONITORCHANGES = 1 << 1
_LINKFLAG_CREATETARGET = 1 << 2
_LINKFLAG_RECURSIVE = 1 << 3
_LINKFLAG_FAILIFSKIPPED = 1 << 4
class VFS:
    """Ordered record of file and directory links to register with usvfs.

    Nothing is sent to usvfs here; the entries are replayed in order by
    ``USVFS.connect_virtual_filesystem``. Paths are stored as absolute strings.
    """

    @dataclasses.dataclass(frozen=True)
    class _Entry:
        # One recorded link.
        mode: Literal["file", "directory"]  # selects the usvfs link call
        source: str  # absolute source path
        target: str  # absolute destination path
        flags: int  # bitwise OR of _LINKFLAG_* values (always an int)

    def __init__(self) -> None:
        self._entries: list[VFS._Entry] = []

    def _add(
        self,
        mode: Literal["file", "directory"],
        source: Path | str,
        destination: Path | str,
        flags: int,
    ) -> None:
        """Append one entry, resolving both paths to absolute strings."""
        self._entries.append(
            VFS._Entry(
                mode=mode,
                source=str(Path(source).absolute()),
                target=str(Path(destination).absolute()),
                flags=flags,
            )
        )

    def link_file(
        self,
        source: Path | str,
        destination: Path | str,
        *,
        fail_if_exists: bool = False,
        fail_if_skipped: bool = False,
    ) -> None:
        """Record a file link from ``source`` to ``destination``.

        Args:
            source: Source file path; made absolute.
            destination: Destination path; made absolute.
            fail_if_exists: Add ``_LINKFLAG_FAILIFEXISTS``.
            fail_if_skipped: Add ``_LINKFLAG_FAILIFSKIPPED``.
        """
        # Build a real int: the former `cond and FLAG` form produced the
        # bool False when no option was set.
        flags = 0
        if fail_if_exists:
            flags |= _LINKFLAG_FAILIFEXISTS
        if fail_if_skipped:
            flags |= _LINKFLAG_FAILIFSKIPPED
        self._add("file", source, destination, flags)

    def link_directory(
        self,
        source: Path | str,
        destination: Path | str,
        *,
        fail_if_exists: bool = False,
        monitor_changes: bool = False,
        recursive: bool = False,
        create_target: bool = False,
        fail_if_skipped: bool = False,
    ) -> None:
        """Record a directory link from ``source`` to ``destination``.

        Args:
            source: Source directory path; made absolute.
            destination: Destination path; made absolute.
            fail_if_exists: Add ``_LINKFLAG_FAILIFEXISTS``.
            monitor_changes: Add ``_LINKFLAG_MONITORCHANGES``.
            recursive: Add ``_LINKFLAG_RECURSIVE``.
            create_target: Add ``_LINKFLAG_CREATETARGET``.
            fail_if_skipped: Add ``_LINKFLAG_FAILIFSKIPPED``.
        """
        flags = 0
        if fail_if_exists:
            flags |= _LINKFLAG_FAILIFEXISTS
        if monitor_changes:
            flags |= _LINKFLAG_MONITORCHANGES
        if recursive:
            flags |= _LINKFLAG_RECURSIVE
        if create_target:
            flags |= _LINKFLAG_CREATETARGET
        if fail_if_skipped:
            flags |= _LINKFLAG_FAILIFSKIPPED
        self._add("directory", source, destination, flags)
class Process:
    """Bookkeeping for a process launched through ``USVFS.run_hooked_process``.

    Holds the STARTUPINFOW handed to the creation call and the
    PROCESS_INFORMATION that call fills in. The properties read the latter.
    """

    def __init__(self, command: str, cwd: Path | None):
        self.command: Final = command
        self.cwd: Final = cwd
        # STARTUPINFOW.cb must carry the size of the structure.
        self.si = _C_STARTUPINFOW(cb=ctypes.sizeof(_C_STARTUPINFOW))
        self.pi = _C_PROCESS_INFORMATION()

    @property
    def process_handle(self) -> int:
        """``hProcess`` member of the process information."""
        info = self.pi
        return info.hProcess

    @property
    def thread_handle(self) -> int:
        """``hThread`` member of the process information."""
        info = self.pi
        return info.hThread

    @property
    def process_id(self) -> int:
        """``dwProcessId`` member of the process information."""
        info = self.pi
        return info.dwProcessId

    @property
    def thread_id(self) -> int:
        """``dwThreadId`` member of the process information."""
        info = self.pi
        return info.dwThreadId
class USVFS:
    """ctypes front-end for the usvfs DLL.

    At most one instance is registered in ``USVFS.instance`` at a time. It
    tracks the connected ``VFS`` and the processes it started that have not
    been waited for yet.
    """

    # The registered wrapper, or None. Set at the end of a successful __init__.
    instance: USVFS | None = None

    def __init__(
        self, path: Path | str, logging: bool | Literal["console"] = True
    ) -> None:
        """Load the usvfs DLL at ``path`` and declare the ctypes signatures used.

        Args:
            path: Location of the usvfs DLL.
            logging: When truthy, ``usvfsInitLogging`` is called. It receives
                True if ``logging == "console"`` and False otherwise.

        Raises:
            RuntimeError: If another instance is already registered.
        """
        # Explicit check instead of `assert`, which is stripped under -O.
        if USVFS.instance is not None:
            raise RuntimeError("a USVFS instance already exists")
        self._usvfs = ctypes.windll.LoadLibrary(str(path))
        self._usvfs.usvfsCreateParameters.restype = ctypes.POINTER(_C_USVFSParameters)
        self._usvfs.usvfsVersionString.restype = ctypes.c_char_p
        self._usvfs.usvfsGetVFSProcessList2.argtypes = (
            ctypes.POINTER(ctypes.c_size_t),
            ctypes.POINTER(ctypes.POINTER(ctypes.c_int32)),
        )
        self._vfs: VFS | None = None
        self._process_by_id: dict[int, Process] = {}
        if logging:
            self._usvfs.usvfsInitLogging(logging == "console")
        # Register only once everything above succeeded. Otherwise a failed
        # load would leave a half-built object blocking all later constructions.
        USVFS.instance = self

    def version(self) -> str:
        """Return the decoded result of ``usvfsVersionString``."""
        return self._usvfs.usvfsVersionString().decode()

    def print_debug_info(self) -> None:
        """Call ``usvfsPrintDebugInfo``."""
        self._usvfs.usvfsPrintDebugInfo()

    def make_parameters(self) -> USVFSParameters:
        """Create a new ``USVFSParameters`` backed by this DLL."""
        return USVFSParameters(self._usvfs)

    def make_virtual_filesystem(self) -> VFS:
        """Create a new, empty ``VFS`` description."""
        return VFS()

    def connect_virtual_filesystem(
        self, parameters: USVFSParameters, filesystem: VFS, disconnect: bool = False
    ) -> None:
        """Connect usvfs with ``parameters`` and register every link in ``filesystem``.

        Args:
            parameters: Parameters passed to ``usvfsConnectVFS``.
            filesystem: Links to register, in insertion order.
            disconnect: Disconnect a currently connected filesystem first
                instead of raising.

        Raises:
            ValueError: If a filesystem is connected and ``disconnect`` is False.
            RuntimeError: If ``usvfsConnectVFS`` or a link call returns a falsy
                value. After a link failure the filesystem stays connected with
                the links registered so far.
        """
        if self._vfs is not None and not disconnect:
            raise ValueError("cannot connect to two virtual filesystem at once")
        if disconnect:
            self.disconnect_virtual_filesystem()
        if not self._usvfs.usvfsConnectVFS(parameters._c_params):  # pyright: ignore[reportPrivateUsage]
            raise RuntimeError("failed to connect virtual filesystem")
        self._vfs = filesystem
        for entry in filesystem._entries:  # pyright: ignore[reportPrivateUsage]
            link = (
                self._usvfs.usvfsVirtualLinkFile
                if entry.mode == "file"
                else self._usvfs.usvfsVirtualLinkDirectoryStatic
            )
            # Checking the result is what gives the fail_if_* flags an effect.
            if not link(entry.source, entry.target, entry.flags):
                raise RuntimeError(
                    f"failed to link {entry.mode} {entry.source!r} -> {entry.target!r}"
                )

    def disconnect_virtual_filesystem(self) -> None:
        """Call ``usvfsDisconnectVFS`` if a filesystem is connected."""
        if self._vfs is not None:
            self._usvfs.usvfsDisconnectVFS()
            self._vfs = None

    @overload
    def run_hooked_process(
        self,
        command: str,
        *,
        cwd: Path | None = None,
        wait: Literal[False],
    ) -> Process: ...

    @overload
    def run_hooked_process(
        self,
        command: str,
        cwd: Path | None = None,
        wait: timedelta | Literal["infinite"] = "infinite",
    ) -> int: ...

    def run_hooked_process(
        self,
        command: str,
        cwd: Path | None = None,
        wait: timedelta | Literal["infinite"] | Literal[False] = "infinite",
    ) -> Process | int:
        """Start ``command`` through ``usvfsCreateProcessHooked``.

        Args:
            command: Full command line.
            cwd: Working directory, or None to pass a NULL directory.
            wait: ``False`` to return the ``Process`` immediately. Otherwise a
                ``timedelta`` (zero included) or ``"infinite"`` forwarded to
                ``wait_for_process``.

        Returns:
            The ``Process`` when not waiting, otherwise its exit code.

        Raises:
            RuntimeError: If creation fails, or propagated from ``wait_for_process``.
        """
        process = Process(command, cwd)
        # CreateProcessW may write into lpCommandLine, so give it a private,
        # writable buffer instead of relying on how ctypes converts a str.
        command_line = ctypes.create_unicode_buffer(command)
        # Arguments follow CreateProcessW's parameter order; None -> NULL.
        created = self._usvfs.usvfsCreateProcessHooked(
            None,  # lpApplicationName
            command_line,  # lpCommandLine
            None,  # lpProcessAttributes
            None,  # lpThreadAttributes
            False,  # bInheritHandles
            0,  # dwCreationFlags
            None,  # lpEnvironment
            None if cwd is None else str(cwd),  # lpCurrentDirectory
            ctypes.byref(process.si),
            ctypes.byref(process.pi),
        )
        if not created:
            raise RuntimeError("failed to create process")
        self._process_by_id[process.process_id] = process
        # Dispatch on type: timedelta(0) is falsy but still means "wait with a
        # zero timeout", which the former `if wait:` silently skipped.
        if isinstance(wait, timedelta) or wait == "infinite":
            return self.wait_for_process(process, wait)
        return process

    def wait_for_process(
        self, process: Process, wait: timedelta | Literal["infinite"] = "infinite"
    ) -> int:
        """Wait for ``process`` to exit, close its handles and return its exit code.

        Args:
            process: A process started by ``run_hooked_process``.
            wait: Maximum wait, or ``"infinite"``. Negative durations are
                treated as zero. Longer ones are capped just below ``INFINITE``.

        Returns:
            The exit code, read as a signed 32-bit integer.

        Raises:
            RuntimeError: If the wait does not end with the process signalled
                (e.g. on timeout; the handles then stay open), or if the exit
                code cannot be read.
        """
        kernel32 = ctypes.windll.kernel32
        infinite = 0xFFFFFFFF  # INFINITE
        wait_ms = infinite
        if wait != "infinite":
            # Keep the value inside the DWORD range, below the INFINITE sentinel.
            wait_ms = min(max(int(wait.total_seconds() * 1000), 0), infinite - 1)
        # Pass handles as pointer-sized values, not as C ints.
        process_handle = ctypes.c_void_p(process.process_handle)
        # 0 is WAIT_OBJECT_0: the process handle was signalled.
        if kernel32.WaitForSingleObject(process_handle, ctypes.c_uint32(wait_ms)) != 0:
            raise RuntimeError("failed to wait for process")
        # NOTE(review): exit codes are DWORDs; values >= 0x80000000 (such as
        # NTSTATUS crash codes) come back negative through this c_int32.
        exit_code = ctypes.c_int32(99)
        if not kernel32.GetExitCodeProcess(process_handle, ctypes.byref(exit_code)):
            raise RuntimeError("failed to get process exit code")
        kernel32.CloseHandle(process_handle)
        kernel32.CloseHandle(ctypes.c_void_p(process.thread_handle))
        self._process_by_id.pop(process.process_id, None)
        return exit_code.value

    def get_hooked_processes(self) -> list[Process]:
        """Return the processes reported by ``usvfsGetVFSProcessList2`` that this
        wrapper started and has not waited for yet.

        Raises:
            RuntimeError: If ``usvfsGetVFSProcessList2`` returns 0.
        """
        size = ctypes.c_size_t(0)
        buffer = ctypes.POINTER(ctypes.c_int32)()
        if not self._usvfs.usvfsGetVFSProcessList2(
            ctypes.byref(size), ctypes.byref(buffer)
        ):
            raise RuntimeError("failed to get VFS process list")
        if size.value == 0:
            return []
        processes = [
            process
            for i in range(size.value)
            if (process := self._process_by_id.get(buffer[i])) is not None
        ]
        # NOTE(review): the buffer returned by usvfs is never freed, so each
        # call leaks it. Freeing it with msvcrt.free crashed; a likely cause is
        # that usvfs allocates it with a different C runtime than msvcrt.
        return processes

    def get_log_messages(self) -> list[str]:
        """Collect messages from ``usvfsGetLogMessages`` until it returns a falsy value.

        The last argument is False (presumably non-blocking — confirm against
        usvfs). Invalid UTF-8, e.g. a message cut at the 2048-byte buffer,
        is replaced rather than aborting the whole drain.
        """
        size = 2048
        buffer = ctypes.create_string_buffer(size)
        messages: list[str] = []
        while self._usvfs.usvfsGetLogMessages(buffer, size, False):
            messages.append(buffer.value.decode(errors="replace"))
        return messages

    def add_blacklisted_executable(self, executable: str) -> None:
        """Pass ``executable`` to ``usvfsBlacklistExecutable``."""
        self._usvfs.usvfsBlacklistExecutable(executable)

    def clear_blacklist_executables(self) -> None:
        """Call ``usvfsClearExecutableBlacklist``."""
        self._usvfs.usvfsClearExecutableBlacklist()

    def add_force_loaded_library(self, process: str, library: str | Path) -> None:
        """Pass ``process`` and ``library`` to ``usvfsForceLoadLibrary``."""
        self._usvfs.usvfsForceLoadLibrary(process, str(library))

    def clear_force_loaded_libraries(self) -> None:
        """Call ``usvfsClearLibraryForceLoads``."""
        self._usvfs.usvfsClearLibraryForceLoads()

    def add_skip_file_suffix(self, suffix: str) -> None:
        """Pass ``suffix`` to ``usvfsAddSkipFileSuffix``."""
        self._usvfs.usvfsAddSkipFileSuffix(suffix)

    def clear_skip_file_suffixes(self) -> None:
        """Call ``usvfsClearSkipFileSuffixes``."""
        self._usvfs.usvfsClearSkipFileSuffixes()

    def add_skip_directory(self, name: str) -> None:
        """Pass ``name`` to ``usvfsAddSkipDirectory``."""
        self._usvfs.usvfsAddSkipDirectory(name)

    def clear_skip_directory(self) -> None:
        """Call ``usvfsClearSkipDirectories``."""
        self._usvfs.usvfsClearSkipDirectories()

    def __del__(self) -> None:
        # __init__ may have raised before _usvfs was assigned.
        try:
            del self._usvfs
        except AttributeError:
            pass
        # Only clear a registration this object owns. A failed second
        # construction must not unregister the live instance.
        # NOTE(review): USVFS.instance holds a strong reference to the
        # registered object, so its finalizer normally only runs after that
        # reference is dropped (or at interpreter shutdown).
        if USVFS.instance is self:
            USVFS.instance = None