From 2f0cd6cdd5ce2b79165341d5a3b90278e1d7956d Mon Sep 17 00:00:00 2001 From: Mark Saroufim Date: Wed, 5 Mar 2025 16:26:51 -0800 Subject: [PATCH 1/4] PTX command prototype --- .../cogs/verify_run_cog.py | 180 +++++++- src/discord-cluster-manager/sandbox.py | 433 ++++++++++++++++++ 2 files changed, 612 insertions(+), 1 deletion(-) create mode 100644 src/discord-cluster-manager/sandbox.py diff --git a/src/discord-cluster-manager/cogs/verify_run_cog.py b/src/discord-cluster-manager/cogs/verify_run_cog.py index e81e2ae6..83862f01 100644 --- a/src/discord-cluster-manager/cogs/verify_run_cog.py +++ b/src/discord-cluster-manager/cogs/verify_run_cog.py @@ -1,6 +1,8 @@ import asyncio import datetime import re +import subprocess +import tempfile import uuid from pathlib import Path from unittest.mock import AsyncMock @@ -11,7 +13,7 @@ from cogs.github_cog import GitHubCog from cogs.leaderboard_cog import LeaderboardSubmitCog from cogs.modal_cog import ModalCog -from consts import SubmissionMode +from consts import CUDA_FLAGS, GPU_TO_SM, SubmissionMode from discord import app_commands from discord.app_commands import Choice from discord.ext import commands @@ -261,6 +263,182 @@ async def verify_submission( # noqa: C901 if report_success: reports.append(f"✅ {run_id:20} {mode.name} behaved as expected") + async def generate_ptx_code(self, source_code: str, gpu_type: str, include_sass: bool = False) -> tuple[bool, str]: + """ + Generate PTX code for a CUDA submission. + + Args: + source_code (str): The CUDA source code + gpu_type (str): The GPU architecture to target + include_sass (bool): Whether to include SASS assembly code + + Returns: + tuple[bool, str]: Success status and the PTX output or error message + """ + # Get the SM architecture code for the specified GPU type + arch = GPU_TO_SM.get(gpu_type) + if not arch: + return False, f"Unknown GPU type: {gpu_type}. Available types: {', '.join(GPU_TO_SM.keys())}" + + try: + # Create a temporary directory for the compilation + with tempfile.TemporaryDirectory() as temp_dir: + temp_path = Path(temp_dir) + source_file = temp_path / "submission.cu" + + # Write the source code to a file + source_file.write_text(source_code) + + # Prepare the compilation command with PTX output flag + ptx_flags = CUDA_FLAGS.copy() + ["-ptx"] + + # Add sass generation flag if requested + if include_sass: + ptx_flags.append("-Xptxas=-v") # Verbose output with sass info + + arch_flag = f"-gencode=arch=compute_{arch},code=compute_{arch}" + + command = ["nvcc"] + ptx_flags + [str(source_file), arch_flag, "-o", str(temp_path / "output.ptx")] + + # Check if nvcc is available + nvcc_check = subprocess.run(["which", "nvcc"], capture_output=True, text=True) + if nvcc_check.returncode != 0: + return False, "NVCC (CUDA compiler) not found. Is CUDA installed?" + + # Run the compilation + process = subprocess.run(command, capture_output=True, text=True) + + # Prepare the output with both stderr (for SASS if requested) and the PTX file + result = "" + + # Include compilation output which contains SASS information + if include_sass and process.stderr: + result += "SASS Assembly Information:\n" + result += "-" * 40 + "\n" + result += process.stderr + "\n" + result += "-" * 40 + "\n\n" + + if process.returncode != 0: + # Compilation failed + return False, f"PTX generation failed:\n{process.stderr}" + + # Read the PTX file + ptx_file = temp_path / "output.ptx" + if ptx_file.exists(): + result += "PTX Code:\n" + result += "-" * 40 + "\n" + result += ptx_file.read_text() + return True, result + else: + return False, "PTX file was not generated" + except Exception as e: + return False, f"Error generating PTX: {str(e)}" + + @app_commands.command(name="ptx") + @app_commands.describe( + submission="The CUDA submission file (.cu extension)", + gpu_type="The GPU architecture to target", + include_sass="Whether to include SASS/assembly output", + as_file="Return the PTX code as a downloadable file instead of text messages" + ) + @app_commands.choices( + gpu_type=[ + Choice(name=gpu, value=gpu) for gpu in GPU_TO_SM.keys() + ] + ) + @with_error_handling + async def ptx_command(self, interaction: discord.Interaction, + submission: discord.Attachment, + gpu_type: Choice[str] = None, + include_sass: bool = False, + as_file: bool = False): + """ + Generate PTX code from a CUDA submission. + + Parameters + ------------ + submission: File + The CUDA submission file (.cu extension) + gpu_type: Choice[str] + The GPU architecture to target + include_sass: bool + Whether to include SASS assembly code in the output + as_file: bool + Return the PTX code as a downloadable file instead of text messages + """ + if not interaction.response.is_done(): + await interaction.response.defer() + + # Validate the file extension + if not submission.filename.endswith('.cu'): + await send_discord_message(interaction, "❌ Only .cu file extensions are supported for PTX generation") + return + + # Set default GPU type to T4 if not specified + target_gpu = gpu_type.value if gpu_type else "T4" + + try: + # Read the submission file + content = await submission.read() + source_code = content.decode('utf-8') + + # Create a thread for the PTX generation + thread_name = f"PTX Generation - {submission.filename} - {target_gpu}" + if include_sass: + thread_name += " with SASS" + + thread = await interaction.channel.create_thread( + name=thread_name, + type=discord.ChannelType.public_thread, + ) + + await thread.send(f"Generating PTX code for {submission.filename} targeting {target_gpu}..." + + (" (including SASS output)" if include_sass else "")) + + # Generate the PTX code + success, result = await self.generate_ptx_code(source_code, target_gpu, include_sass) + + if success: + if as_file: + # Create a temporary file containing the PTX output + with tempfile.NamedTemporaryFile('w', suffix='.ptx', delete=False) as temp_file: + temp_file.write(result) + temp_file_path = temp_file.name + + # Get the base filename without extension + base_filename = Path(submission.filename).stem + output_filename = f"{base_filename}_{target_gpu}.ptx" + + # Send the file + await thread.send( + f"PTX code for {submission.filename} targeting {target_gpu}:", + file=discord.File(temp_file_path, filename=output_filename) + ) + + # Remove the temporary file + Path(temp_file_path).unlink(missing_ok=True) + else: + # Split the PTX code into chunks if it's too long for Discord + max_msg_length = 1900 # Slightly less than 2000 to account for markdown + chunks = [result[i:i+max_msg_length] for i in range(0, len(result), max_msg_length)] + + for i, chunk in enumerate(chunks): + await thread.send(f"```{chunk}```") + + # Send a summary message + await thread.send(f"✅ PTX code generation complete for {target_gpu} GPU" + + (" with SASS assembly" if include_sass else "")) + else: + # Send the error message + await thread.send(f"❌ Failed to generate PTX code: {result}") + + # Notify user in the original channel + await send_discord_message(interaction, f"PTX generation for {submission.filename} is complete. Check the thread for results.") + + except Exception as e: + logger.error(f"Error generating PTX: {e}", exc_info=True) + await send_discord_message(interaction, f"❌ Error generating PTX: {str(e)}") + @app_commands.command(name="verifyruns") async def verify_runs(self, interaction: discord.Interaction): """Verify runs on Modal, GitHub Nvidia, and GitHub AMD.""" diff --git a/src/discord-cluster-manager/sandbox.py b/src/discord-cluster-manager/sandbox.py new file mode 100644 index 00000000..c3f39f73 --- /dev/null +++ b/src/discord-cluster-manager/sandbox.py @@ -0,0 +1,433 @@ +import sys +import builtins +import importlib +import importlib.abc +import importlib.machinery +import types +import os +import re +import ast +import signal +import resource +import threading +from contextlib import contextmanager +from typing import Dict, List, Set, Optional, Any, Callable, Tuple + +# List of dangerous functions and modules to restrict +RESTRICTED_MODULES = { + 'os': {'system', 'popen', 'spawn', 'exec', 'execl', 'execle', 'execlp', 'execlpe', + 'execv', 'execve', 'execvp', 'execvpe', 'startfile', 'rename', 'remove', 'unlink', + 'rmdir', 'mkdir', 'makedirs', 'fork', 'forkpty', 'killpg', 'kill', '_exit', 'setuid', + 'seteuid', 'setreuid', 'setgid', 'setegid', 'setregid', 'chdir', 'fchdir', 'chroot', + 'chmod', 'chown', 'lchown', 'fchown', 'symlink', 'truncate', 'ftruncate', 'putenv', + 'unsetenv', 'environ'}, + 'sys': {'exit', '_exit', 'modules', 'path', 'meta_path', 'exitfunc', 'displayhook'}, + 'subprocess': {'*'}, # Block the entire module + 'multiprocessing': {'*'}, # Block the entire module + 'importlib': {'*'}, # Block direct importlib usage + 'builtins': {'__import__', 'eval', 'exec', 'compile', 'open', 'input', 'breakpoint'}, + 'ctypes': {'*'}, # Block the entire module + 'shutil': {'copyfileobj', 'copyfile', 'copy', 'copy2', 'copytree', 'move', 'rmtree'}, + 'socket': {'*'}, # Block network operations + 'urllib': {'*'}, # Block network operations + 'urllib.request': {'*'}, # Block network operations + 'http': {'*'}, # Block network operations + 'requests': {'*'}, # Block network operations + 'pip': {'*'}, # Block pip installations + 'setuptools': {'*'}, # Block package installations + 'pkg_resources': {'*'}, # Block package management + 'distutils': {'*'}, # Block package installations +} + +# List of allowed imports for computation +ALLOWED_IMPORTS = { + 'torch', 'numpy', 'math', 'random', 'time', 'functools', 'itertools', + 'collections', 'copy', 'datetime', 'json', 're', 'typing', + # Allow specific standard library modules that are safe + 'abc', 'array', 'bisect', 'calendar', 'contextlib', 'decimal', 'enum', + 'fractions', 'heapq', 'numbers', 'statistics', 'string', 'textwrap', + # Allow submodules from torch + 'torch.nn', 'torch.optim', 'torch.cuda', 'torch.utils', 'torch.distributions', + # Allow NumPy submodules + 'numpy.random', 'numpy.linalg', 'numpy.fft' +} + +# Set to track currently allowed imports (can be dynamically modified) +_CURRENTLY_ALLOWED = set(ALLOWED_IMPORTS) + +# Original built-in import function +_original_import = builtins.__import__ + +# Track if sandbox is active +_sandbox_active = False + +class RestrictedImportError(ImportError): + """Raised when an import is blocked by the sandbox.""" + pass + +def _safe_import(name, globals=None, locals=None, fromlist=(), level=0): + """ + A replacement for the built-in __import__ function that restricts imports. + """ + if not _sandbox_active: + return _original_import(name, globals, locals, fromlist, level) + + # Check if the module is in the allowed list + if name in _CURRENTLY_ALLOWED: + return _original_import(name, globals, locals, fromlist, level) + + # Check if it's a submodule of an allowed module + for allowed in _CURRENTLY_ALLOWED: + if name.startswith(allowed + '.'): + return _original_import(name, globals, locals, fromlist, level) + + # Explicit check for built-in modules that should always be allowed + if name in sys.builtin_module_names and name not in RESTRICTED_MODULES: + return _original_import(name, globals, locals, fromlist, level) + + # Check if the module is in the restricted list + for restricted_module, restricted_attrs in RESTRICTED_MODULES.items(): + if name == restricted_module: + if '*' in restricted_attrs: + raise RestrictedImportError( + f"Import of module '{name}' is not allowed in the sandbox environment. " + f"This module is restricted for security reasons." + ) + + # If only specific attributes are restricted, wrap the module + module = _original_import(name, globals, locals, fromlist, level) + return _create_restricted_module(module, restricted_attrs) + + # For any other imports, log and deny + raise RestrictedImportError( + f"Import of module '{name}' is not allowed in the sandbox environment. " + f"Only specific modules required for computational tasks are permitted." + ) + +def _create_restricted_module(module, restricted_attrs): + """Create a wrapper around a module that blocks access to specific attributes.""" + class RestrictedModule: + def __init__(self, module, restricted_attrs): + self._module = module + self._restricted_attrs = restricted_attrs + + def __getattr__(self, name): + if name in self._restricted_attrs: + raise AttributeError( + f"Access to '{self._module.__name__}.{name}' is not allowed in the sandbox environment." + ) + return getattr(self._module, name) + + return RestrictedModule(module, restricted_attrs) + +def _restricted_exec(code, globals=None, locals=None): + """ + A replacement for the built-in exec function that raises an error. + """ + raise RuntimeError("Use of exec() is not allowed in the sandbox environment.") + +def _restricted_eval(expr, globals=None, locals=None): + """ + A replacement for the built-in eval function that raises an error. + """ + raise RuntimeError("Use of eval() is not allowed in the sandbox environment.") + +def _restricted_compile(*args, **kwargs): + """ + A replacement for the built-in compile function that raises an error. + """ + raise RuntimeError("Use of compile() is not allowed in the sandbox environment.") + +def _restricted_open(file, mode='r', buffering=-1, encoding=None, errors=None, newline=None, closefd=True, opener=None): + """ + A replacement for the built-in open function that only allows reading from specific directories. + """ + # Convert to absolute path + if not os.path.isabs(file): + file = os.path.abspath(file) + + # Only allow reading, not writing + if 'w' in mode or 'a' in mode or '+' in mode or 'x' in mode: + raise IOError(f"Write operations are not allowed in the sandbox environment.") + + # Call the original open function + return _original_open(file, mode, buffering, encoding, errors, newline, closefd, opener) + +# Original built-in functions +_original_open = builtins.open +_original_exec = builtins.exec +_original_eval = builtins.eval +_original_compile = builtins.compile + +class Sandbox: + """ + A context manager that creates a restricted execution environment for user code. + """ + def __init__(self, additional_allowed_imports=None): + self.additional_allowed_imports = additional_allowed_imports or [] + self.original_builtins = {} + self.original_sys_modules = {} + + def __enter__(self): + global _sandbox_active + _sandbox_active = True + + # Add any additional allowed imports + for module in self.additional_allowed_imports: + _CURRENTLY_ALLOWED.add(module) + + # Save original built-ins and replace with restricted versions + self.original_builtins = { + '__import__': builtins.__import__, + 'exec': builtins.exec, + 'eval': builtins.eval, + 'compile': builtins.compile, + 'open': builtins.open, + } + + # Replace built-ins with restricted versions + builtins.__import__ = _safe_import + builtins.exec = _restricted_exec + builtins.eval = _restricted_eval + builtins.compile = _restricted_compile + builtins.open = _restricted_open + + # Disable sys.modules manipulation + self.original_sys_modules = sys.modules.copy() + + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + global _sandbox_active + _sandbox_active = False + + # Restore original built-ins + for name, func in self.original_builtins.items(): + setattr(builtins, name, func) + + # Restore sys.modules + for module_name in list(sys.modules.keys()): + if module_name not in self.original_sys_modules: + del sys.modules[module_name] + + # Remove any additional allowed imports we added + for module in self.additional_allowed_imports: + if module in _CURRENTLY_ALLOWED: + _CURRENTLY_ALLOWED.remove(module) + + return False # Don't suppress exceptions + +# AST-based static code analyzer for submissions +class CodeAnalyzer(ast.NodeVisitor): + """ + Static code analyzer that detects potentially malicious patterns in Python code. + """ + def __init__(self): + self.issues = [] + self.imports = set() + self.suspicious_calls = [] + + def visit_Import(self, node): + for name in node.names: + self.imports.add(name.name) + # Check for suspicious imports + if name.name in RESTRICTED_MODULES or name.name.split('.')[0] in RESTRICTED_MODULES: + self.issues.append(f"Suspicious import: {name.name}") + self.generic_visit(node) + + def visit_ImportFrom(self, node): + module_name = node.module + self.imports.add(module_name) + # Check for suspicious imports + if module_name in RESTRICTED_MODULES or module_name.split('.')[0] in RESTRICTED_MODULES: + self.issues.append(f"Suspicious import from: {module_name}") + self.generic_visit(node) + + def visit_Call(self, node): + # Check for calls to exec, eval, etc. + if isinstance(node.func, ast.Name): + if node.func.id in {'exec', 'eval', 'compile', '__import__'}: + self.issues.append(f"Suspicious call to {node.func.id}()") + elif isinstance(node.func, ast.Attribute): + if isinstance(node.func.value, ast.Name): + # Check for os.system, subprocess.call, etc. + if node.func.value.id in {'os', 'subprocess', 'sys', 'shutil'}: + self.suspicious_calls.append(f"{node.func.value.id}.{node.func.attr}") + self.issues.append(f"Suspicious call to {node.func.value.id}.{node.func.attr}()") + self.generic_visit(node) + +def analyze_code(code: str) -> Tuple[bool, List[str]]: + """ + Performs static analysis on the code to detect potentially harmful patterns. + + Args: + code: The source code to analyze + + Returns: + Tuple of (is_safe, issues) + """ + try: + tree = ast.parse(code) + analyzer = CodeAnalyzer() + analyzer.visit(tree) + + is_safe = len(analyzer.issues) == 0 + return is_safe, analyzer.issues + except SyntaxError as e: + return False, [f"Syntax error in code: {e}"] + except Exception as e: + return False, [f"Error analyzing code: {e}"] + +def execute_submission_safely(submission_path, func_name="custom_kernel", *args, **kwargs): + """ + Safely execute a user's submission by importing it in a sandboxed environment. + + Args: + submission_path: Path to the submission.py file + func_name: Name of the function to call in the submission + *args, **kwargs: Arguments to pass to the function + + Returns: + The result of the function call + """ + # Get directory and filename + dir_path = os.path.dirname(os.path.abspath(submission_path)) + file_name = os.path.basename(submission_path) + module_name = os.path.splitext(file_name)[0] + + # Read the code and perform static analysis + with open(submission_path, 'r') as f: + code = f.read() + + is_safe, issues = analyze_code(code) + if not is_safe: + issues_str = "\n".join(issues) + raise SecurityError(f"The submission contains potentially unsafe code:\n{issues_str}") + + # Save original directory and switch to submission directory + original_dir = os.getcwd() + original_path = sys.path.copy() + + try: + os.chdir(dir_path) + if dir_path not in sys.path: + sys.path.insert(0, dir_path) + + # Use the sandbox for importing and executing + with Sandbox(): + # Import the module + module = importlib.import_module(module_name) + + # Get the function + if not hasattr(module, func_name): + raise AttributeError(f"Module {module_name} does not have a function named {func_name}") + + func = getattr(module, func_name) + + # Call the function + return func(*args, **kwargs) + finally: + # Restore original directory and path + os.chdir(original_dir) + sys.path = original_path + +class SecurityError(Exception): + """Raised when a security violation is detected.""" + pass + +class TimeoutError(Exception): + """Raised when execution times out.""" + pass + +class MemoryLimitError(Exception): + """Raised when execution exceeds memory limits.""" + pass + +@contextmanager +def time_limit(seconds): + """ + Sets a time limit on code execution using SIGALRM. + + Args: + seconds: Maximum execution time in seconds + + Raises: + TimeoutError: If execution time exceeds the limit + """ + # Define signal handler + def signal_handler(signum, frame): + raise TimeoutError(f"Code execution timed out after {seconds} seconds") + + # Set signal handler and alarm + signal.signal(signal.SIGALRM, signal_handler) + signal.alarm(seconds) + + try: + yield + finally: + # Reset the alarm + signal.alarm(0) + +def set_memory_limit(memory_mb): + """ + Sets a memory limit for the current process. + + Args: + memory_mb: Maximum memory in megabytes + """ + soft, hard = resource.getrlimit(resource.RLIMIT_AS) + memory_bytes = memory_mb * 1024 * 1024 + resource.setrlimit(resource.RLIMIT_AS, (memory_bytes, hard)) + +def run_with_limitations(func, args=None, kwargs=None, timeout_sec=10, memory_mb=1024): + """ + Runs a function with time and memory limitations. + + Args: + func: The function to run + args: Arguments to pass to the function + kwargs: Keyword arguments to pass to the function + timeout_sec: Maximum execution time in seconds + memory_mb: Maximum memory usage in megabytes + + Returns: + The result of the function call + + Raises: + TimeoutError: If execution time exceeds the limit + MemoryLimitError: If memory usage exceeds the limit + """ + args = args or () + kwargs = kwargs or {} + + result = [None] + exception = [None] + + def target(): + try: + # Set memory limit for this thread + set_memory_limit(memory_mb) + result[0] = func(*args, **kwargs) + except Exception as e: + exception[0] = e + + # Create and start thread + thread = threading.Thread(target=target) + + try: + with time_limit(timeout_sec): + thread.start() + thread.join(timeout=timeout_sec) + except TimeoutError as e: + # If we get here, the code timed out + raise e + + if thread.is_alive(): + # Thread is still running after timeout + raise TimeoutError(f"Code execution timed out after {timeout_sec} seconds") + + if exception[0]: + # Re-raise any exception from the thread + raise exception[0] + + return result[0] \ No newline at end of file From aa47461a3f8d5b0b3f30b1b7b823dadaf1ea5cdb Mon Sep 17 00:00:00 2001 From: Mark Saroufim Date: Wed, 5 Mar 2025 16:39:20 -0800 Subject: [PATCH 2/4] Delete src/discord-cluster-manager/sandbox.py --- src/discord-cluster-manager/sandbox.py | 433 ------------------------- 1 file changed, 433 deletions(-) delete mode 100644 src/discord-cluster-manager/sandbox.py diff --git a/src/discord-cluster-manager/sandbox.py b/src/discord-cluster-manager/sandbox.py deleted file mode 100644 index c3f39f73..00000000 --- a/src/discord-cluster-manager/sandbox.py +++ /dev/null @@ -1,433 +0,0 @@ -import sys -import builtins -import importlib -import importlib.abc -import importlib.machinery -import types -import os -import re -import ast -import signal -import resource -import threading -from contextlib import contextmanager -from typing import Dict, List, Set, Optional, Any, Callable, Tuple - -# List of dangerous functions and modules to restrict -RESTRICTED_MODULES = { - 'os': {'system', 'popen', 'spawn', 'exec', 'execl', 'execle', 'execlp', 'execlpe', - 'execv', 'execve', 'execvp', 'execvpe', 'startfile', 'rename', 'remove', 'unlink', - 'rmdir', 'mkdir', 'makedirs', 'fork', 'forkpty', 'killpg', 'kill', '_exit', 'setuid', - 'seteuid', 'setreuid', 'setgid', 'setegid', 'setregid', 'chdir', 'fchdir', 'chroot', - 'chmod', 'chown', 'lchown', 'fchown', 'symlink', 'truncate', 'ftruncate', 'putenv', - 'unsetenv', 'environ'}, - 'sys': {'exit', '_exit', 'modules', 'path', 'meta_path', 'exitfunc', 'displayhook'}, - 'subprocess': {'*'}, # Block the entire module - 'multiprocessing': {'*'}, # Block the entire module - 'importlib': {'*'}, # Block direct importlib usage - 'builtins': {'__import__', 'eval', 'exec', 'compile', 'open', 'input', 'breakpoint'}, - 'ctypes': {'*'}, # Block the entire module - 'shutil': {'copyfileobj', 'copyfile', 'copy', 'copy2', 'copytree', 'move', 'rmtree'}, - 'socket': {'*'}, # Block network operations - 'urllib': {'*'}, # Block network operations - 'urllib.request': {'*'}, # Block network operations - 'http': {'*'}, # Block network operations - 'requests': {'*'}, # Block network operations - 'pip': {'*'}, # Block pip installations - 'setuptools': {'*'}, # Block package installations - 'pkg_resources': {'*'}, # Block package management - 'distutils': {'*'}, # Block package installations -} - -# List of allowed imports for computation -ALLOWED_IMPORTS = { - 'torch', 'numpy', 'math', 'random', 'time', 'functools', 'itertools', - 'collections', 'copy', 'datetime', 'json', 're', 'typing', - # Allow specific standard library modules that are safe - 'abc', 'array', 'bisect', 'calendar', 'contextlib', 'decimal', 'enum', - 'fractions', 'heapq', 'numbers', 'statistics', 'string', 'textwrap', - # Allow submodules from torch - 'torch.nn', 'torch.optim', 'torch.cuda', 'torch.utils', 'torch.distributions', - # Allow NumPy submodules - 'numpy.random', 'numpy.linalg', 'numpy.fft' -} - -# Set to track currently allowed imports (can be dynamically modified) -_CURRENTLY_ALLOWED = set(ALLOWED_IMPORTS) - -# Original built-in import function -_original_import = builtins.__import__ - -# Track if sandbox is active -_sandbox_active = False - -class RestrictedImportError(ImportError): - """Raised when an import is blocked by the sandbox.""" - pass - -def _safe_import(name, globals=None, locals=None, fromlist=(), level=0): - """ - A replacement for the built-in __import__ function that restricts imports. - """ - if not _sandbox_active: - return _original_import(name, globals, locals, fromlist, level) - - # Check if the module is in the allowed list - if name in _CURRENTLY_ALLOWED: - return _original_import(name, globals, locals, fromlist, level) - - # Check if it's a submodule of an allowed module - for allowed in _CURRENTLY_ALLOWED: - if name.startswith(allowed + '.'): - return _original_import(name, globals, locals, fromlist, level) - - # Explicit check for built-in modules that should always be allowed - if name in sys.builtin_module_names and name not in RESTRICTED_MODULES: - return _original_import(name, globals, locals, fromlist, level) - - # Check if the module is in the restricted list - for restricted_module, restricted_attrs in RESTRICTED_MODULES.items(): - if name == restricted_module: - if '*' in restricted_attrs: - raise RestrictedImportError( - f"Import of module '{name}' is not allowed in the sandbox environment. " - f"This module is restricted for security reasons." - ) - - # If only specific attributes are restricted, wrap the module - module = _original_import(name, globals, locals, fromlist, level) - return _create_restricted_module(module, restricted_attrs) - - # For any other imports, log and deny - raise RestrictedImportError( - f"Import of module '{name}' is not allowed in the sandbox environment. " - f"Only specific modules required for computational tasks are permitted." - ) - -def _create_restricted_module(module, restricted_attrs): - """Create a wrapper around a module that blocks access to specific attributes.""" - class RestrictedModule: - def __init__(self, module, restricted_attrs): - self._module = module - self._restricted_attrs = restricted_attrs - - def __getattr__(self, name): - if name in self._restricted_attrs: - raise AttributeError( - f"Access to '{self._module.__name__}.{name}' is not allowed in the sandbox environment." - ) - return getattr(self._module, name) - - return RestrictedModule(module, restricted_attrs) - -def _restricted_exec(code, globals=None, locals=None): - """ - A replacement for the built-in exec function that raises an error. - """ - raise RuntimeError("Use of exec() is not allowed in the sandbox environment.") - -def _restricted_eval(expr, globals=None, locals=None): - """ - A replacement for the built-in eval function that raises an error. - """ - raise RuntimeError("Use of eval() is not allowed in the sandbox environment.") - -def _restricted_compile(*args, **kwargs): - """ - A replacement for the built-in compile function that raises an error. - """ - raise RuntimeError("Use of compile() is not allowed in the sandbox environment.") - -def _restricted_open(file, mode='r', buffering=-1, encoding=None, errors=None, newline=None, closefd=True, opener=None): - """ - A replacement for the built-in open function that only allows reading from specific directories. - """ - # Convert to absolute path - if not os.path.isabs(file): - file = os.path.abspath(file) - - # Only allow reading, not writing - if 'w' in mode or 'a' in mode or '+' in mode or 'x' in mode: - raise IOError(f"Write operations are not allowed in the sandbox environment.") - - # Call the original open function - return _original_open(file, mode, buffering, encoding, errors, newline, closefd, opener) - -# Original built-in functions -_original_open = builtins.open -_original_exec = builtins.exec -_original_eval = builtins.eval -_original_compile = builtins.compile - -class Sandbox: - """ - A context manager that creates a restricted execution environment for user code. - """ - def __init__(self, additional_allowed_imports=None): - self.additional_allowed_imports = additional_allowed_imports or [] - self.original_builtins = {} - self.original_sys_modules = {} - - def __enter__(self): - global _sandbox_active - _sandbox_active = True - - # Add any additional allowed imports - for module in self.additional_allowed_imports: - _CURRENTLY_ALLOWED.add(module) - - # Save original built-ins and replace with restricted versions - self.original_builtins = { - '__import__': builtins.__import__, - 'exec': builtins.exec, - 'eval': builtins.eval, - 'compile': builtins.compile, - 'open': builtins.open, - } - - # Replace built-ins with restricted versions - builtins.__import__ = _safe_import - builtins.exec = _restricted_exec - builtins.eval = _restricted_eval - builtins.compile = _restricted_compile - builtins.open = _restricted_open - - # Disable sys.modules manipulation - self.original_sys_modules = sys.modules.copy() - - return self - - def __exit__(self, exc_type, exc_val, exc_tb): - global _sandbox_active - _sandbox_active = False - - # Restore original built-ins - for name, func in self.original_builtins.items(): - setattr(builtins, name, func) - - # Restore sys.modules - for module_name in list(sys.modules.keys()): - if module_name not in self.original_sys_modules: - del sys.modules[module_name] - - # Remove any additional allowed imports we added - for module in self.additional_allowed_imports: - if module in _CURRENTLY_ALLOWED: - _CURRENTLY_ALLOWED.remove(module) - - return False # Don't suppress exceptions - -# AST-based static code analyzer for submissions -class CodeAnalyzer(ast.NodeVisitor): - """ - Static code analyzer that detects potentially malicious patterns in Python code. - """ - def __init__(self): - self.issues = [] - self.imports = set() - self.suspicious_calls = [] - - def visit_Import(self, node): - for name in node.names: - self.imports.add(name.name) - # Check for suspicious imports - if name.name in RESTRICTED_MODULES or name.name.split('.')[0] in RESTRICTED_MODULES: - self.issues.append(f"Suspicious import: {name.name}") - self.generic_visit(node) - - def visit_ImportFrom(self, node): - module_name = node.module - self.imports.add(module_name) - # Check for suspicious imports - if module_name in RESTRICTED_MODULES or module_name.split('.')[0] in RESTRICTED_MODULES: - self.issues.append(f"Suspicious import from: {module_name}") - self.generic_visit(node) - - def visit_Call(self, node): - # Check for calls to exec, eval, etc. - if isinstance(node.func, ast.Name): - if node.func.id in {'exec', 'eval', 'compile', '__import__'}: - self.issues.append(f"Suspicious call to {node.func.id}()") - elif isinstance(node.func, ast.Attribute): - if isinstance(node.func.value, ast.Name): - # Check for os.system, subprocess.call, etc. - if node.func.value.id in {'os', 'subprocess', 'sys', 'shutil'}: - self.suspicious_calls.append(f"{node.func.value.id}.{node.func.attr}") - self.issues.append(f"Suspicious call to {node.func.value.id}.{node.func.attr}()") - self.generic_visit(node) - -def analyze_code(code: str) -> Tuple[bool, List[str]]: - """ - Performs static analysis on the code to detect potentially harmful patterns. - - Args: - code: The source code to analyze - - Returns: - Tuple of (is_safe, issues) - """ - try: - tree = ast.parse(code) - analyzer = CodeAnalyzer() - analyzer.visit(tree) - - is_safe = len(analyzer.issues) == 0 - return is_safe, analyzer.issues - except SyntaxError as e: - return False, [f"Syntax error in code: {e}"] - except Exception as e: - return False, [f"Error analyzing code: {e}"] - -def execute_submission_safely(submission_path, func_name="custom_kernel", *args, **kwargs): - """ - Safely execute a user's submission by importing it in a sandboxed environment. - - Args: - submission_path: Path to the submission.py file - func_name: Name of the function to call in the submission - *args, **kwargs: Arguments to pass to the function - - Returns: - The result of the function call - """ - # Get directory and filename - dir_path = os.path.dirname(os.path.abspath(submission_path)) - file_name = os.path.basename(submission_path) - module_name = os.path.splitext(file_name)[0] - - # Read the code and perform static analysis - with open(submission_path, 'r') as f: - code = f.read() - - is_safe, issues = analyze_code(code) - if not is_safe: - issues_str = "\n".join(issues) - raise SecurityError(f"The submission contains potentially unsafe code:\n{issues_str}") - - # Save original directory and switch to submission directory - original_dir = os.getcwd() - original_path = sys.path.copy() - - try: - os.chdir(dir_path) - if dir_path not in sys.path: - sys.path.insert(0, dir_path) - - # Use the sandbox for importing and executing - with Sandbox(): - # Import the module - module = importlib.import_module(module_name) - - # Get the function - if not hasattr(module, func_name): - raise AttributeError(f"Module {module_name} does not have a function named {func_name}") - - func = getattr(module, func_name) - - # Call the function - return func(*args, **kwargs) - finally: - # Restore original directory and path - os.chdir(original_dir) - sys.path = original_path - -class SecurityError(Exception): - """Raised when a security violation is detected.""" - pass - -class TimeoutError(Exception): - """Raised when execution times out.""" - pass - -class MemoryLimitError(Exception): - """Raised when execution exceeds memory limits.""" - pass - -@contextmanager -def time_limit(seconds): - """ - Sets a time limit on code execution using SIGALRM. - - Args: - seconds: Maximum execution time in seconds - - Raises: - TimeoutError: If execution time exceeds the limit - """ - # Define signal handler - def signal_handler(signum, frame): - raise TimeoutError(f"Code execution timed out after {seconds} seconds") - - # Set signal handler and alarm - signal.signal(signal.SIGALRM, signal_handler) - signal.alarm(seconds) - - try: - yield - finally: - # Reset the alarm - signal.alarm(0) - -def set_memory_limit(memory_mb): - """ - Sets a memory limit for the current process. - - Args: - memory_mb: Maximum memory in megabytes - """ - soft, hard = resource.getrlimit(resource.RLIMIT_AS) - memory_bytes = memory_mb * 1024 * 1024 - resource.setrlimit(resource.RLIMIT_AS, (memory_bytes, hard)) - -def run_with_limitations(func, args=None, kwargs=None, timeout_sec=10, memory_mb=1024): - """ - Runs a function with time and memory limitations. - - Args: - func: The function to run - args: Arguments to pass to the function - kwargs: Keyword arguments to pass to the function - timeout_sec: Maximum execution time in seconds - memory_mb: Maximum memory usage in megabytes - - Returns: - The result of the function call - - Raises: - TimeoutError: If execution time exceeds the limit - MemoryLimitError: If memory usage exceeds the limit - """ - args = args or () - kwargs = kwargs or {} - - result = [None] - exception = [None] - - def target(): - try: - # Set memory limit for this thread - set_memory_limit(memory_mb) - result[0] = func(*args, **kwargs) - except Exception as e: - exception[0] = e - - # Create and start thread - thread = threading.Thread(target=target) - - try: - with time_limit(timeout_sec): - thread.start() - thread.join(timeout=timeout_sec) - except TimeoutError as e: - # If we get here, the code timed out - raise e - - if thread.is_alive(): - # Thread is still running after timeout - raise TimeoutError(f"Code execution timed out after {timeout_sec} seconds") - - if exception[0]: - # Re-raise any exception from the thread - raise exception[0] - - return result[0] \ No newline at end of file From cfcbce1d4d6435e5afe2e872242baefc52c22090 Mon Sep 17 00:00:00 2001 From: Mark Saroufim Date: Wed, 5 Mar 2025 16:47:45 -0800 Subject: [PATCH 3/4] move --- .../cogs/submit_cog.py | 182 +++++++++++++++++- .../cogs/verify_run_cog.py | 179 +---------------- 2 files changed, 182 insertions(+), 179 deletions(-) diff --git a/src/discord-cluster-manager/cogs/submit_cog.py b/src/discord-cluster-manager/cogs/submit_cog.py index 7574739b..c4f813be 100644 --- a/src/discord-cluster-manager/cogs/submit_cog.py +++ b/src/discord-cluster-manager/cogs/submit_cog.py @@ -1,13 +1,17 @@ from enum import Enum from typing import TYPE_CHECKING, Optional, Tuple, Type +import tempfile +import subprocess +from pathlib import Path if TYPE_CHECKING: from bot import ClusterBot import discord from better_profanity import profanity -from consts import SubmissionMode +from consts import CUDA_FLAGS, GPU_TO_SM, SubmissionMode from discord import app_commands +from discord.app_commands import Choice from discord.ext import commands from report import generate_report from run_eval import FullResult @@ -228,3 +232,179 @@ async def _run_submission( def _get_arch(self, gpu_type: app_commands.Choice[str]): raise NotImplementedError() + + async def generate_ptx_code(self, source_code: str, gpu_type: str, include_sass: bool = False) -> tuple[bool, str]: + """ + Generate PTX code for a CUDA submission. + + Args: + source_code (str): The CUDA source code + gpu_type (str): The GPU architecture to target + include_sass (bool): Whether to include SASS assembly code + + Returns: + tuple[bool, str]: Success status and the PTX output or error message + """ + # Get the SM architecture code for the specified GPU type + arch = GPU_TO_SM.get(gpu_type) + if not arch: + return False, f"Unknown GPU type: {gpu_type}. Available types: {', '.join(GPU_TO_SM.keys())}" + + try: + # Create a temporary directory for the compilation + with tempfile.TemporaryDirectory() as temp_dir: + temp_path = Path(temp_dir) + source_file = temp_path / "submission.cu" + + # Write the source code to a file + source_file.write_text(source_code) + + # Prepare the compilation command with PTX output flag + ptx_flags = CUDA_FLAGS.copy() + ["-ptx"] + + # Add sass generation flag if requested + if include_sass: + ptx_flags.append("-Xptxas=-v") # Verbose output with sass info + + arch_flag = f"-gencode=arch=compute_{arch},code=compute_{arch}" + + command = ["nvcc"] + ptx_flags + [str(source_file), arch_flag, "-o", str(temp_path / "output.ptx")] + + # Check if nvcc is available + nvcc_check = subprocess.run(["which", "nvcc"], capture_output=True, text=True) + if nvcc_check.returncode != 0: + return False, "NVCC (CUDA compiler) not found. Is CUDA installed?" + + # Run the compilation + process = subprocess.run(command, capture_output=True, text=True) + + # Prepare the output with both stderr (for SASS if requested) and the PTX file + result = "" + + # Include compilation output which contains SASS information + if include_sass and process.stderr: + result += "SASS Assembly Information:\n" + result += "-" * 40 + "\n" + result += process.stderr + "\n" + result += "-" * 40 + "\n\n" + + if process.returncode != 0: + # Compilation failed + return False, f"PTX generation failed:\n{process.stderr}" + + # Read the PTX file + ptx_file = temp_path / "output.ptx" + if ptx_file.exists(): + result += "PTX Code:\n" + result += "-" * 40 + "\n" + result += ptx_file.read_text() + return True, result + else: + return False, "PTX file was not generated" + except Exception as e: + return False, f"Error generating PTX: {str(e)}" + + @app_commands.command(name="ptx") + @app_commands.describe( + submission="The CUDA submission file (.cu extension)", + gpu_type="The GPU architecture to target", + include_sass="Whether to include SASS/assembly output", + as_file="Return the PTX code as a downloadable file instead of text messages" + ) + @app_commands.choices( + gpu_type=[ + Choice(name=gpu, value=gpu) for gpu in GPU_TO_SM.keys() + ] + ) + @with_error_handling + async def ptx_command(self, interaction: discord.Interaction, + submission: discord.Attachment, + gpu_type: Choice[str] = None, + include_sass: bool = False, + as_file: bool = False): + """ + Generate PTX code from a CUDA submission. + + Parameters + ------------ + submission: File + The CUDA submission file (.cu extension) + gpu_type: Choice[str] + The GPU architecture to target + include_sass: bool + Whether to include SASS assembly code in the output + as_file: bool + Return the PTX code as a downloadable file instead of text messages + """ + if not interaction.response.is_done(): + await interaction.response.defer() + + # Validate the file extension + if not submission.filename.endswith('.cu'): + await send_discord_message(interaction, "❌ Only .cu file extensions are supported for PTX generation") + return + + # Set default GPU type to T4 if not specified + target_gpu = gpu_type.value if gpu_type else "T4" + + try: + # Read the submission file + content = await submission.read() + source_code = content.decode('utf-8') + + # Create a thread for the PTX generation + thread_name = f"PTX Generation - {submission.filename} - {target_gpu}" + if include_sass: + thread_name += " with SASS" + + thread = await interaction.channel.create_thread( + name=thread_name, + type=discord.ChannelType.public_thread, + ) + + await thread.send(f"Generating PTX code for {submission.filename} targeting {target_gpu}..." + + (" (including SASS output)" if include_sass else "")) + + # Generate the PTX code + success, result = await self.generate_ptx_code(source_code, target_gpu, include_sass) + + if success: + if as_file: + # Create a temporary file containing the PTX output + with tempfile.NamedTemporaryFile('w', suffix='.ptx', delete=False) as temp_file: + temp_file.write(result) + temp_file_path = temp_file.name + + # Get the base filename without extension + base_filename = Path(submission.filename).stem + output_filename = f"{base_filename}_{target_gpu}.ptx" + + # Send the file + await thread.send( + f"PTX code for {submission.filename} targeting {target_gpu}:", + file=discord.File(temp_file_path, filename=output_filename) + ) + + # Remove the temporary file + Path(temp_file_path).unlink(missing_ok=True) + else: + # Split the PTX code into chunks if it's too long for Discord + max_msg_length = 1900 # Slightly less than 2000 to account for markdown + chunks = [result[i:i+max_msg_length] for i in range(0, len(result), max_msg_length)] + + for i, chunk in enumerate(chunks): + await thread.send(f"```{chunk}```") + + # Send a summary message + await thread.send(f"✅ PTX code generation complete for {target_gpu} GPU" + + (" with SASS assembly" if include_sass else "")) + else: + # Send the error message + await thread.send(f"❌ Failed to generate PTX code: {result}") + + # Notify user in the original channel + await send_discord_message(interaction, f"PTX generation for {submission.filename} is complete. Check the thread for results.") + + except Exception as e: + logger.error(f"Error generating PTX: {e}", exc_info=True) + await send_discord_message(interaction, f"❌ Error generating PTX: {str(e)}") diff --git a/src/discord-cluster-manager/cogs/verify_run_cog.py b/src/discord-cluster-manager/cogs/verify_run_cog.py index 83862f01..7e236477 100644 --- a/src/discord-cluster-manager/cogs/verify_run_cog.py +++ b/src/discord-cluster-manager/cogs/verify_run_cog.py @@ -1,8 +1,6 @@ import asyncio import datetime import re -import subprocess -import tempfile import uuid from pathlib import Path from unittest.mock import AsyncMock @@ -13,7 +11,7 @@ from cogs.github_cog import GitHubCog from cogs.leaderboard_cog import LeaderboardSubmitCog from cogs.modal_cog import ModalCog -from consts import CUDA_FLAGS, GPU_TO_SM, SubmissionMode +from consts import SubmissionMode from discord import app_commands from discord.app_commands import Choice from discord.ext import commands @@ -263,181 +261,6 @@ async def verify_submission( # noqa: C901 if report_success: reports.append(f"✅ {run_id:20} {mode.name} behaved as expected") - async def generate_ptx_code(self, source_code: str, gpu_type: str, include_sass: bool = False) -> tuple[bool, str]: - """ - Generate PTX code for a CUDA submission. - - Args: - source_code (str): The CUDA source code - gpu_type (str): The GPU architecture to target - include_sass (bool): Whether to include SASS assembly code - - Returns: - tuple[bool, str]: Success status and the PTX output or error message - """ - # Get the SM architecture code for the specified GPU type - arch = GPU_TO_SM.get(gpu_type) - if not arch: - return False, f"Unknown GPU type: {gpu_type}. Available types: {', '.join(GPU_TO_SM.keys())}" - - try: - # Create a temporary directory for the compilation - with tempfile.TemporaryDirectory() as temp_dir: - temp_path = Path(temp_dir) - source_file = temp_path / "submission.cu" - - # Write the source code to a file - source_file.write_text(source_code) - - # Prepare the compilation command with PTX output flag - ptx_flags = CUDA_FLAGS.copy() + ["-ptx"] - - # Add sass generation flag if requested - if include_sass: - ptx_flags.append("-Xptxas=-v") # Verbose output with sass info - - arch_flag = f"-gencode=arch=compute_{arch},code=compute_{arch}" - - command = ["nvcc"] + ptx_flags + [str(source_file), arch_flag, "-o", str(temp_path / "output.ptx")] - - # Check if nvcc is available - nvcc_check = subprocess.run(["which", "nvcc"], capture_output=True, text=True) - if nvcc_check.returncode != 0: - return False, "NVCC (CUDA compiler) not found. Is CUDA installed?" - - # Run the compilation - process = subprocess.run(command, capture_output=True, text=True) - - # Prepare the output with both stderr (for SASS if requested) and the PTX file - result = "" - - # Include compilation output which contains SASS information - if include_sass and process.stderr: - result += "SASS Assembly Information:\n" - result += "-" * 40 + "\n" - result += process.stderr + "\n" - result += "-" * 40 + "\n\n" - - if process.returncode != 0: - # Compilation failed - return False, f"PTX generation failed:\n{process.stderr}" - - # Read the PTX file - ptx_file = temp_path / "output.ptx" - if ptx_file.exists(): - result += "PTX Code:\n" - result += "-" * 40 + "\n" - result += ptx_file.read_text() - return True, result - else: - return False, "PTX file was not generated" - except Exception as e: - return False, f"Error generating PTX: {str(e)}" - - @app_commands.command(name="ptx") - @app_commands.describe( - submission="The CUDA submission file (.cu extension)", - gpu_type="The GPU architecture to target", - include_sass="Whether to include SASS/assembly output", - as_file="Return the PTX code as a downloadable file instead of text messages" - ) - @app_commands.choices( - gpu_type=[ - Choice(name=gpu, value=gpu) for gpu in GPU_TO_SM.keys() - ] - ) - @with_error_handling - async def ptx_command(self, interaction: discord.Interaction, - submission: discord.Attachment, - gpu_type: Choice[str] = None, - include_sass: bool = False, - as_file: bool = False): - """ - Generate PTX code from a CUDA submission. - - Parameters - ------------ - submission: File - The CUDA submission file (.cu extension) - gpu_type: Choice[str] - The GPU architecture to target - include_sass: bool - Whether to include SASS assembly code in the output - as_file: bool - Return the PTX code as a downloadable file instead of text messages - """ - if not interaction.response.is_done(): - await interaction.response.defer() - - # Validate the file extension - if not submission.filename.endswith('.cu'): - await send_discord_message(interaction, "❌ Only .cu file extensions are supported for PTX generation") - return - - # Set default GPU type to T4 if not specified - target_gpu = gpu_type.value if gpu_type else "T4" - - try: - # Read the submission file - content = await submission.read() - source_code = content.decode('utf-8') - - # Create a thread for the PTX generation - thread_name = f"PTX Generation - {submission.filename} - {target_gpu}" - if include_sass: - thread_name += " with SASS" - - thread = await interaction.channel.create_thread( - name=thread_name, - type=discord.ChannelType.public_thread, - ) - - await thread.send(f"Generating PTX code for {submission.filename} targeting {target_gpu}..." + - (" (including SASS output)" if include_sass else "")) - - # Generate the PTX code - success, result = await self.generate_ptx_code(source_code, target_gpu, include_sass) - - if success: - if as_file: - # Create a temporary file containing the PTX output - with tempfile.NamedTemporaryFile('w', suffix='.ptx', delete=False) as temp_file: - temp_file.write(result) - temp_file_path = temp_file.name - - # Get the base filename without extension - base_filename = Path(submission.filename).stem - output_filename = f"{base_filename}_{target_gpu}.ptx" - - # Send the file - await thread.send( - f"PTX code for {submission.filename} targeting {target_gpu}:", - file=discord.File(temp_file_path, filename=output_filename) - ) - - # Remove the temporary file - Path(temp_file_path).unlink(missing_ok=True) - else: - # Split the PTX code into chunks if it's too long for Discord - max_msg_length = 1900 # Slightly less than 2000 to account for markdown - chunks = [result[i:i+max_msg_length] for i in range(0, len(result), max_msg_length)] - - for i, chunk in enumerate(chunks): - await thread.send(f"```{chunk}```") - - # Send a summary message - await thread.send(f"✅ PTX code generation complete for {target_gpu} GPU" + - (" with SASS assembly" if include_sass else "")) - else: - # Send the error message - await thread.send(f"❌ Failed to generate PTX code: {result}") - - # Notify user in the original channel - await send_discord_message(interaction, f"PTX generation for {submission.filename} is complete. Check the thread for results.") - - except Exception as e: - logger.error(f"Error generating PTX: {e}", exc_info=True) - await send_discord_message(interaction, f"❌ Error generating PTX: {str(e)}") @app_commands.command(name="verifyruns") async def verify_runs(self, interaction: discord.Interaction): From 1dffe5d2caeb099cbb21c5d05c6293d221387a35 Mon Sep 17 00:00:00 2001 From: Mark Saroufim Date: Wed, 5 Mar 2025 16:53:33 -0800 Subject: [PATCH 4/4] ruff --- .../cogs/submit_cog.py | 72 +++++++++---------- 1 file changed, 36 insertions(+), 36 deletions(-) diff --git a/src/discord-cluster-manager/cogs/submit_cog.py b/src/discord-cluster-manager/cogs/submit_cog.py index c4f813be..ec27e37f 100644 --- a/src/discord-cluster-manager/cogs/submit_cog.py +++ b/src/discord-cluster-manager/cogs/submit_cog.py @@ -1,8 +1,8 @@ -from enum import Enum -from typing import TYPE_CHECKING, Optional, Tuple, Type -import tempfile import subprocess +import tempfile +from enum import Enum from pathlib import Path +from typing import TYPE_CHECKING, Optional, Tuple, Type if TYPE_CHECKING: from bot import ClusterBot @@ -232,7 +232,7 @@ async def _run_submission( def _get_arch(self, gpu_type: app_commands.Choice[str]): raise NotImplementedError() - + async def generate_ptx_code(self, source_code: str, gpu_type: str, include_sass: bool = False) -> tuple[bool, str]: """ Generate PTX code for a CUDA submission. @@ -249,49 +249,49 @@ async def generate_ptx_code(self, source_code: str, gpu_type: str, include_sass: arch = GPU_TO_SM.get(gpu_type) if not arch: return False, f"Unknown GPU type: {gpu_type}. Available types: {', '.join(GPU_TO_SM.keys())}" - + try: # Create a temporary directory for the compilation with tempfile.TemporaryDirectory() as temp_dir: temp_path = Path(temp_dir) source_file = temp_path / "submission.cu" - + # Write the source code to a file source_file.write_text(source_code) - + # Prepare the compilation command with PTX output flag ptx_flags = CUDA_FLAGS.copy() + ["-ptx"] - + # Add sass generation flag if requested if include_sass: ptx_flags.append("-Xptxas=-v") # Verbose output with sass info - + arch_flag = f"-gencode=arch=compute_{arch},code=compute_{arch}" - + command = ["nvcc"] + ptx_flags + [str(source_file), arch_flag, "-o", str(temp_path / "output.ptx")] - + # Check if nvcc is available nvcc_check = subprocess.run(["which", "nvcc"], capture_output=True, text=True) if nvcc_check.returncode != 0: return False, "NVCC (CUDA compiler) not found. Is CUDA installed?" - + # Run the compilation process = subprocess.run(command, capture_output=True, text=True) - + # Prepare the output with both stderr (for SASS if requested) and the PTX file result = "" - + # Include compilation output which contains SASS information if include_sass and process.stderr: result += "SASS Assembly Information:\n" result += "-" * 40 + "\n" result += process.stderr + "\n" result += "-" * 40 + "\n\n" - + if process.returncode != 0: # Compilation failed return False, f"PTX generation failed:\n{process.stderr}" - + # Read the PTX file ptx_file = temp_path / "output.ptx" if ptx_file.exists(): @@ -303,7 +303,7 @@ async def generate_ptx_code(self, source_code: str, gpu_type: str, include_sass: return False, "PTX file was not generated" except Exception as e: return False, f"Error generating PTX: {str(e)}" - + @app_commands.command(name="ptx") @app_commands.describe( submission="The CUDA submission file (.cu extension)", @@ -317,8 +317,8 @@ async def generate_ptx_code(self, source_code: str, gpu_type: str, include_sass: ] ) @with_error_handling - async def ptx_command(self, interaction: discord.Interaction, - submission: discord.Attachment, + async def ptx_command(self, interaction: discord.Interaction, + submission: discord.Attachment, gpu_type: Choice[str] = None, include_sass: bool = False, as_file: bool = False): @@ -338,73 +338,73 @@ async def ptx_command(self, interaction: discord.Interaction, """ if not interaction.response.is_done(): await interaction.response.defer() - + # Validate the file extension if not submission.filename.endswith('.cu'): await send_discord_message(interaction, "❌ Only .cu file extensions are supported for PTX generation") return - + # Set default GPU type to T4 if not specified target_gpu = gpu_type.value if gpu_type else "T4" - + try: # Read the submission file content = await submission.read() source_code = content.decode('utf-8') - + # Create a thread for the PTX generation thread_name = f"PTX Generation - {submission.filename} - {target_gpu}" if include_sass: thread_name += " with SASS" - + thread = await interaction.channel.create_thread( name=thread_name, type=discord.ChannelType.public_thread, ) - - await thread.send(f"Generating PTX code for {submission.filename} targeting {target_gpu}..." + + + await thread.send(f"Generating PTX code for {submission.filename} targeting {target_gpu}..." + (" (including SASS output)" if include_sass else "")) - + # Generate the PTX code success, result = await self.generate_ptx_code(source_code, target_gpu, include_sass) - + if success: if as_file: # Create a temporary file containing the PTX output with tempfile.NamedTemporaryFile('w', suffix='.ptx', delete=False) as temp_file: temp_file.write(result) temp_file_path = temp_file.name - + # Get the base filename without extension base_filename = Path(submission.filename).stem output_filename = f"{base_filename}_{target_gpu}.ptx" - + # Send the file await thread.send( f"PTX code for {submission.filename} targeting {target_gpu}:", file=discord.File(temp_file_path, filename=output_filename) ) - + # Remove the temporary file Path(temp_file_path).unlink(missing_ok=True) else: # Split the PTX code into chunks if it's too long for Discord max_msg_length = 1900 # Slightly less than 2000 to account for markdown chunks = [result[i:i+max_msg_length] for i in range(0, len(result), max_msg_length)] - + for i, chunk in enumerate(chunks): await thread.send(f"```{chunk}```") - + # Send a summary message - await thread.send(f"✅ PTX code generation complete for {target_gpu} GPU" + + await thread.send(f"✅ PTX code generation complete for {target_gpu} GPU" + (" with SASS assembly" if include_sass else "")) else: # Send the error message await thread.send(f"❌ Failed to generate PTX code: {result}") - + # Notify user in the original channel await send_discord_message(interaction, f"PTX generation for {submission.filename} is complete. Check the thread for results.") - + except Exception as e: logger.error(f"Error generating PTX: {e}", exc_info=True) await send_discord_message(interaction, f"❌ Error generating PTX: {str(e)}")