5 changes: 4 additions & 1 deletion .gitignore
@@ -154,4 +154,7 @@ dmypy.json
 weights
 save/
 checkpoints/
-runs/
+runs/
+
+# Claude Code
+.claude/*
5,827 changes: 5,827 additions & 0 deletions poetry.lock

Large diffs are not rendered by default.

91 changes: 91 additions & 0 deletions pyproject.toml
@@ -0,0 +1,91 @@
[tool.poetry]
name = "sciglm"
version = "1.0.0"
description = "SciGLM: language models for scientific reasoning tasks"
authors = ["THUDM <team@thudm.ai>"]
readme = "README.md"
packages = []

[tool.poetry.dependencies]
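# 3.9.7 is carved out of the range below, most likely to satisfy streamlit's "python != 3.9.7" requirement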
python = ">=3.9,<3.9.7 || >3.9.7,<4.0"
protobuf = "*"
transformers = "4.30.2"
cpm_kernels = "*"
torch = ">=2.0"
gradio = "*"
mdtex2html = "*"
sentencepiece = "*"
accelerate = "*"
sse-starlette = "*"
streamlit = ">=1.24.0"
rouge_chinese = "*"
jieba = "*"
datasets = "*"
nltk = "*"
wandb = "*"

[tool.poetry.group.test.dependencies]
pytest = "^7.4.0"
pytest-cov = "^4.1.0"
pytest-mock = "^3.12.0"


[tool.pytest.ini_options]
minversion = "6.0"
addopts = [
    "-ra",
    "--strict-markers",
    "--strict-config",
    "--cov=.",
    "--cov-report=term-missing:skip-covered",
    "--cov-report=html:htmlcov",
    "--cov-report=xml",
    "--cov-fail-under=80",
]
testpaths = ["tests"]
python_files = ["test_*.py", "*_test.py"]
python_classes = ["Test*"]
python_functions = ["test_*"]
markers = [
    "unit: marks tests as unit tests (deselect with '-m \"not unit\"')",
    "integration: marks tests as integration tests (deselect with '-m \"not integration\"')",
    "slow: marks tests as slow (deselect with '-m \"not slow\"')",
]

[tool.coverage.run]
source = ["."]
omit = [
    "*/tests/*",
    "*/test_*",
    "*/.venv/*",
    "*/venv/*",
    "*/env/*",
    "*/__pycache__/*",
    "*/build/*",
    "*/dist/*",
    "*/.*",
    "setup.py",
]

[tool.coverage.report]
exclude_lines = [
    "pragma: no cover",
    "def __repr__",
    "if self.debug:",
    "if settings.DEBUG",
    "raise AssertionError",
    "raise NotImplementedError",
    "if 0:",
    "if __name__ == .__main__.:",
    "class .*\\bProtocol\\):",
    "@(abc\\.)?abstractmethod",
]
show_missing = true
precision = 2

[tool.coverage.html]
directory = "htmlcov"

[build-system]
requires = ["poetry-core"]
build-backend = "poetry.core.masonry.api"
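
For reference, a minimal sketch of how the markers registered above would be applied in a test module; the file name and test names here are hypothetical and not part of this PR:

# tests/test_markers_example.py (hypothetical illustration)
import pytest


@pytest.mark.unit
def test_addition_is_commutative():
    # Selected by: pytest -m unit; deselected by: -m "not unit"
    assert 1 + 2 == 2 + 1


@pytest.mark.integration
@pytest.mark.slow
def test_end_to_end_smoke():
    # Deselected by: -m "not slow" or -m "not integration"
    assert True

With the addopts above, a bare pytest invocation already collects coverage and fails the run below 80%; subsets are selected with pytest -m unit or pytest -m "not slow".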
1 change: 1 addition & 0 deletions tests/__init__.py
@@ -0,0 +1 @@
# Tests package initialization
171 changes: 171 additions & 0 deletions tests/conftest.py
@@ -0,0 +1,171 @@
"""
Shared pytest fixtures for the SciGLM project.

This module contains common fixtures that can be used across all test files.
"""

import json
import os
import shutil
import tempfile
import warnings
from pathlib import Path
from typing import Any, Dict, Generator

import pytest


@pytest.fixture
def temp_dir() -> Generator[Path, None, None]:
    """
    Create a temporary directory for testing.

    Yields:
        Path: Path to the temporary directory
    """
    temp_path = Path(tempfile.mkdtemp())
    try:
        yield temp_path
    finally:
        shutil.rmtree(temp_path, ignore_errors=True)


@pytest.fixture
def temp_file() -> Generator[Path, None, None]:
    """
    Create a temporary file for testing.

    Yields:
        Path: Path to the temporary file
    """
    temp_fd, temp_path = tempfile.mkstemp()
    temp_file_path = Path(temp_path)
    try:
        os.close(temp_fd)
        yield temp_file_path
    finally:
        if temp_file_path.exists():
            temp_file_path.unlink()


@pytest.fixture
def sample_config() -> Dict[str, Any]:
    """
    Provide a sample configuration dictionary for testing.

    Returns:
        Dict[str, Any]: Sample configuration
    """
    return {
        "model_name": "test_model",
        "max_length": 512,
        "temperature": 0.7,
        "batch_size": 4,
        "learning_rate": 1e-5,
        "epochs": 3,
        "save_path": "/tmp/test_model",
        "device": "cpu",
    }


@pytest.fixture
def sample_text_data() -> list[str]:
    """
    Provide sample text data for testing.

    Returns:
        list[str]: List of sample text strings
    """
    return [
        "This is a sample text for testing purposes.",
        "Another example of text data used in tests.",
        "Scientific text about machine learning and AI.",
        "Test data for natural language processing tasks.",
    ]


@pytest.fixture
def mock_model_config() -> Dict[str, Any]:
    """
    Provide mock model configuration for testing.

    Returns:
        Dict[str, Any]: Mock model configuration
    """
    return {
        "vocab_size": 50000,
        "hidden_size": 768,
        "num_hidden_layers": 12,
        "num_attention_heads": 12,
        "intermediate_size": 3072,
        "max_position_embeddings": 512,
        "type_vocab_size": 2,
        "initializer_range": 0.02,
        "layer_norm_eps": 1e-12,
        "pad_token_id": 0,
        "position_embedding_type": "absolute",
        "use_cache": True,
        "classifier_dropout": None,
    }


@pytest.fixture(autouse=True)
def clean_environment():
    """
    Clean up environment variables before and after each test.

    This fixture automatically runs for every test to ensure a clean environment.
    """
    # Store original environment
    original_env = dict(os.environ)

    # Clean up test-related environment variables
    test_vars = [var for var in os.environ.keys() if var.startswith("TEST_")]
    for var in test_vars:
        os.environ.pop(var, None)

    yield

    # Restore original environment
    os.environ.clear()
    os.environ.update(original_env)


@pytest.fixture
def mock_dataset_path(temp_dir: Path) -> Path:
    """
    Create a mock dataset file for testing.

    Args:
        temp_dir: Temporary directory fixture

    Returns:
        Path: Path to the mock dataset file
    """
    dataset_file = temp_dir / "mock_dataset.json"
    mock_data = [
        {"input": "What is AI?", "output": "AI stands for Artificial Intelligence."},
        {"input": "Explain machine learning.", "output": "Machine learning is a subset of AI."},
        {"input": "What is deep learning?", "output": "Deep learning uses neural networks."},
    ]

    with open(dataset_file, "w", encoding="utf-8") as f:
        json.dump(mock_data, f, ensure_ascii=False, indent=2)

    return dataset_file


@pytest.fixture
def suppress_warnings():
    """
    Suppress common warnings during testing.

    Filters are applied inside warnings.catch_warnings so the process-wide
    filter state is restored after the test, rather than wiped wholesale by
    warnings.resetwarnings().
    """
    with warnings.catch_warnings():
        # Suppress specific warnings that commonly appear during testing
        warnings.filterwarnings("ignore", category=UserWarning, module="transformers")
        warnings.filterwarnings("ignore", category=FutureWarning, module="torch")
        warnings.filterwarnings("ignore", category=DeprecationWarning)
        yield
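
A minimal sketch of how these fixtures compose in a test; the module below is hypothetical and relies only on the fixtures defined in this conftest:

# tests/test_fixture_usage_example.py (hypothetical illustration)
import json
from pathlib import Path
from typing import Any, Dict


def test_mock_dataset_roundtrip(mock_dataset_path: Path, sample_config: Dict[str, Any]):
    # mock_dataset_path builds on temp_dir, so the JSON file lives in a
    # per-test directory that is deleted automatically after the test.
    records = json.loads(mock_dataset_path.read_text(encoding="utf-8"))
    assert len(records) == 3
    assert all({"input", "output"} <= record.keys() for record in records)
    assert sample_config["device"] == "cpu"

Because clean_environment is autouse, every test in the suite also runs with TEST_-prefixed environment variables stripped and the original environment restored afterwards, without requesting the fixture explicitly.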
1 change: 1 addition & 0 deletions tests/integration/__init__.py
@@ -0,0 +1 @@
# Integration tests package initialization