Skip to content

Commit 7a12150

Browse files
committed
u
1 parent 0a25aad commit 7a12150

File tree

2 files changed

+168
-1
lines changed

2 files changed

+168
-1
lines changed

.pre-commit-config.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
repos:
1616
# Standard hooks
1717
- repo: https://github.com/pre-commit/pre-commit-hooks
18-
rev: v4.5.0
18+
rev: v6.0.0
1919
hooks:
2020
- id: check-case-conflict
2121
- id: check-docstring-first

add_tags.py

Lines changed: 167 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,167 @@
1+
from pathlib import Path
2+
from pydantic_ai import Agent, ModelSettings, capture_run_messages
3+
from pydantic_ai.providers.openai import OpenAIProvider
4+
from pydantic_ai.models.openai import OpenAIChatModel
5+
from pydantic import BaseModel
6+
from annotated_types import Gt, Lt
7+
from typing import Annotated
8+
import json
9+
import nbformat
10+
from typing import Literal
11+
import asyncio
12+
import logging
13+
from rich.logging import RichHandler
14+
15+
logging.basicConfig(
16+
level="INFO",
17+
format="%(message)s",
18+
datefmt="[%X]",
19+
handlers=[RichHandler(rich_tracebacks=True)],
20+
)
21+
22+
log = logging.getLogger(__name__)
23+
24+
model = OpenAIChatModel(
25+
"",
26+
provider=OpenAIProvider(
27+
base_url="http://localhost:8080/v1",
28+
),
29+
settings=ModelSettings(temperature=0.5, max_tokens=1000),
30+
)
31+
32+
valid_tags_raw = """
33+
physics: Post is related to physics, especially particle physics.
34+
science: Post is about science other than physics.
35+
programming: The post is primarily about programming, discussing language features or libraries.
36+
high performance computing: Post is about running software efficiently and fast, typically dealing with benchmarks.
37+
statistics: Post is related to statistics.
38+
llm: Post is related to LLMs (Large Language Models) or uses LLMs, for example through agents.
39+
philosophy: Post touches philosophy.
40+
engineering: Post is about engineering.
41+
opinion: Post expresses opinions.
42+
data analysis: Post is about data analysis.
43+
visualization: Post is primarily about data visualization.
44+
graphical design: Post is about graphical design.
45+
parsing: Post deals with parsing input.
46+
bootstrap: Post is about the bootstrap method in statistics.
47+
uncertainty analysis: Post is about the statistical uncertainty estimation, confidence interval estimation, or uncertainty propagation (uncertainty = error in this context).
48+
sWeights: Posts about sWeights or COWs (custom orthogonal weight functions).
49+
symbolic computation: Post uses symbolic computation with sympy.
50+
simulation: Post is about simulation of statistical or other processes.
51+
neural networks: Post is about (deep) neural networks.
52+
machine learning: Post is about machine learning other than with neural networks.
53+
prompt engineering: Post is about prompt engineering.
54+
web scraping: Post is about web scraping.
55+
environment: Post is about energy consumption and other topics that affect Earth's environment.
56+
"""
57+
58+
valid_tags = {
59+
v[0]: v[1] for v in (v.split(":") for v in valid_tags_raw.strip().split("\n"))
60+
}
61+
62+
63+
AllowedTags = Literal[*valid_tags]
64+
65+
66+
class TagWithConfidence(BaseModel):
67+
tag: AllowedTags # type:ignore
68+
confidence: Annotated[float, Gt(0), Lt(1)]
69+
70+
71+
tag_agent = Agent(
72+
model,
73+
output_type=list[TagWithConfidence],
74+
system_prompt="Return tags that match the provided post.",
75+
instructions=f"""
76+
Respond with a list of all tags that match the post.
77+
78+
All valid tags:
79+
80+
{"- ".join(f"{k}: {v}" for (k, v) in valid_tags.items())}
81+
82+
You must use one of these tags, you cannot invent new ones.
83+
""",
84+
)
85+
86+
87+
fn_tag_db = Path("tag_db.json")
88+
89+
if fn_tag_db.exists():
90+
with fn_tag_db.open(encoding="utf-8") as f:
91+
tag_db = json.load(f)
92+
else:
93+
tag_db = {}
94+
95+
96+
async def get_tags(fn: Path) -> set[str]:
97+
with open(fn, encoding="utf-8") as f:
98+
if fn.suffix == ".ipynb":
99+
# We clean the notebook before passing it to the LLM
100+
nb = nbformat.read(f, as_version=4)
101+
nb.metadata = {}
102+
for cell in nb.cells:
103+
if cell.cell_type == "code":
104+
cell.outputs = []
105+
cell.execution_count = None
106+
cell.metadata = {}
107+
doc = nbformat.writes(nb)
108+
elif fn.suffix == ".md":
109+
doc = f.read()
110+
111+
tag_input = f"{fn!s}:\n\n{doc}" # type:ignore
112+
113+
joined_tags = set()
114+
for i in range(3):
115+
# To get a more complete set of tags, we iterate the call.
116+
with capture_run_messages() as messages:
117+
try:
118+
result = await tag_agent.run(tag_input)
119+
log.info(f"{fn.name} [{i}] {result.output}")
120+
tags = set(x.tag for x in result.output if x.confidence >= 0.8)
121+
joined_tags |= tags
122+
log.debug(messages)
123+
except Exception:
124+
# If there is an error (typically a schema validation error),
125+
# print the messages for debugging.
126+
log.exception(messages)
127+
raise
128+
log.info(f"{fn.name} {joined_tags}")
129+
return joined_tags
130+
131+
132+
async def main():
133+
input_files = [Path(fn) for fn in Path("_posts").rglob("*.*")]
134+
135+
to_process = []
136+
for fn in input_files:
137+
if fn.suffix not in (".ipynb", ".md"):
138+
continue
139+
140+
# skip files that have been processed already
141+
if fn.name in tag_db:
142+
continue
143+
144+
to_process.append(fn)
145+
146+
try:
147+
for fn in sorted(to_process):
148+
tags = await get_tags(fn)
149+
if tags:
150+
# A sorted list is easier to diff if we update tags.
151+
tag_db[fn.name] = list(sorted(tags))
152+
else:
153+
log.error(f"No tags for {fn.name!r}")
154+
# save after every change, in case something breaks
155+
with fn_tag_db.open("w", encoding="utf-8") as f:
156+
json.dump(
157+
dict(sorted(tag_db.items(), key=lambda x: x[0].lower())),
158+
f,
159+
indent=2,
160+
)
161+
162+
except Exception:
163+
log.exception("Fatal error")
164+
raise SystemExit("Fatal error")
165+
166+
167+
asyncio.run(main())

0 commit comments

Comments
 (0)