Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 79 additions & 0 deletions app/database.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
import sqlite3
import os
from datetime import datetime
from pathlib import Path
from contextlib import contextmanager

# Resolve the database location relative to this source file so the path
# does not depend on the process's current working directory.
# parent.parent puts the DB under the project root, not inside app/.
BASE_DIR = Path(__file__).resolve().parent.parent
DB_FILE = BASE_DIR / "data" / "chord_fingerprints.db"

# Ensure the data directory exists (import-time side effect, idempotent).
DB_FILE.parent.mkdir(parents=True, exist_ok=True)

def get_db_connection():
    """Open a fresh SQLite connection to the fingerprint database.

    Returns:
        sqlite3.Connection with ``row_factory`` set to :class:`sqlite3.Row`,
        so result columns are accessible by name as well as by index.
    """
    connection = sqlite3.connect(str(DB_FILE))
    connection.row_factory = sqlite3.Row
    return connection

@contextmanager
def db_cursor():
    """Context manager yielding a ``(connection, cursor)`` pair.

    Commits the transaction when the ``with`` body exits cleanly; on any
    exception the transaction is explicitly rolled back and the exception
    re-raised. The connection is closed in all cases.

    Yields:
        tuple[sqlite3.Connection, sqlite3.Cursor]
    """
    conn = get_db_connection()
    try:
        yield conn, conn.cursor()
        conn.commit()
    except Exception:
        # Previously partial writes were only discarded implicitly by
        # close(); make the rollback explicit and let the error propagate.
        conn.rollback()
        raise
    finally:
        conn.close()

def init_db():
    """Create the fingerprint schema if it does not already exist.

    Safe to call repeatedly: every DDL statement uses ``IF NOT EXISTS``.
    Creates the segment-level fingerprint table, an index on its lookup
    column, and the whole-file cache table.
    """
    with db_cursor() as (conn, c):
        # Segment-level fingerprints: one row per (pHash, detected chord).
        # Several rows may share a phash, hence a plain column (not a PK).
        c.execute('''
            CREATE TABLE IF NOT EXISTS chord_fingerprints (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                phash TEXT NOT NULL,
                chord_symbol TEXT NOT NULL,
                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
            )
        ''')

        # Lookups on this table are by phash; without an index every match
        # is a full table scan as the fingerprint corpus grows.
        c.execute('''
            CREATE INDEX IF NOT EXISTS idx_chord_fingerprints_phash
            ON chord_fingerprints (phash)
        ''')

        # File-level cache: whole-file pHash -> serialized progression JSON.
        # phash is the primary key, so one cached result per file hash.
        c.execute('''
            CREATE TABLE IF NOT EXISTS file_cache (
                phash TEXT PRIMARY KEY,
                progression_data TEXT NOT NULL,
                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
            )
        ''')

def save_phash_chord_pair(phash: str, chord_symbol: str, cursor=None):
    """Persist one (pHash, chord symbol) pair in ``chord_fingerprints``.

    Args:
        phash: Perceptual hash of the audio segment.
        chord_symbol: Detected chord symbol for that segment.
        cursor: Optional existing cursor. When supplied it is used directly,
            leaving transaction control to the caller (batch inserts);
            otherwise a short-lived connection is opened and committed.
    """
    sql = 'INSERT INTO chord_fingerprints (phash, chord_symbol) VALUES (?, ?)'
    params = (phash, chord_symbol)
    if cursor:
        cursor.execute(sql, params)
    else:
        with db_cursor() as (_conn, own_cursor):
            own_cursor.execute(sql, params)

def get_cached_progression(phash: str):
    """Return the cached progression JSON for a whole-file pHash.

    Args:
        phash: Perceptual hash identifying the audio file.

    Returns:
        The stored ``progression_data`` string, or ``None`` on a cache miss
        so callers can fall through to the full analysis pipeline.
    """
    with db_cursor() as (_conn, cur):
        cur.execute("SELECT progression_data FROM file_cache WHERE phash=?", (phash,))
        hit = cur.fetchone()
        return hit['progression_data'] if hit else None

def save_cached_progression(phash: str, progression_data: str):
    """Upsert the chord-progression JSON cached for a whole-file pHash.

    Uses ``INSERT OR REPLACE`` so re-analysing the same file refreshes the
    existing row instead of violating the ``phash`` primary key.

    Args:
        phash: Perceptual hash identifying the audio file.
        progression_data: Serialized (JSON) chord progression to cache.
    """
    upsert = (
        'INSERT OR REPLACE INTO file_cache (phash, progression_data) '
        'VALUES (?, ?)'
    )
    with db_cursor() as (_conn, cur):
        cur.execute(upsert, (phash, progression_data))
4 changes: 2 additions & 2 deletions app/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,8 +86,8 @@ class E2EBaseRequest(BaseModel):

class E2EBaseResult(BaseModel):
jobId: str
transcriptionUrl: str
separatedAudioUrl: str
transcriptionUrl: Optional[str] = None
separatedAudioUrl: Optional[str] = None
chordProgressionUrl: str
format: ChartFormat = ChartFormat.JSON

Expand Down
94 changes: 87 additions & 7 deletions app/tasks/chord_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,15 +284,56 @@ def e2e_base_ready_task(self, audio_file_path: str, instrument: str):
from demucs.apply import apply_model
import torch
import torchaudio
import librosa
import numpy as np
from PIL import Image
import imagehash
from halmoni import MIDIAnalyzer, ChordDetector, KeyDetector, ChordProgression
import json
from app.database import (
save_phash_chord_pair, db_cursor, init_db,
get_cached_progression, save_cached_progression
)

self.update_progress(0, 100, "Starting E2E pipeline")

try:
output_dir = Path(f"./outputs/{job_id}")
output_dir.mkdir(parents=True, exist_ok=True)

# Initialize DB once
init_db()

# Check cache
self.update_progress(5, 100, "Checking cache")
try:
# Compute file pHash
y, sr = librosa.load(audio_file_path, sr=22050) # Use 22050Hz for pHash consistency
if len(y) > 0:
mfcc = librosa.feature.mfcc(y=y, sr=sr, n_mfcc=20)
mfcc_norm = (mfcc - mfcc.min()) / (mfcc.max() - mfcc.min() + 1e-6) * 255
mfcc_img = Image.fromarray(mfcc_norm.astype(np.uint8))
file_phash = str(imagehash.phash(mfcc_img))

cached_data = get_cached_progression(file_phash)
if cached_data:
self.update_progress(100, 100, "Found cached result")
# Save cached data to output file
chord_output_path = output_dir / "chord_progression.json"
with open(chord_output_path, 'w') as f:
f.write(cached_data)

return {
'jobId': job_id,
'transcriptionUrl': None,
'separatedAudioUrl': None,
'chordProgressionUrl': f'/outputs/{job_id}/chord_progression.json',
'format': 'json'
}
except Exception as e:
print(f"Warning: Cache check failed: {e}")
file_phash = None

# Step 1: Audio separation
self.update_progress(10, 100, "Separating audio")

Expand Down Expand Up @@ -347,12 +388,43 @@ def e2e_base_ready_task(self, audio_file_path: str, instrument: str):
time_windows = analyzer.get_time_windows(all_notes_flat, window_size=1.0)

chords = []
for window_start, window_notes in time_windows:
notes = analyzer.group_simultaneous_notes(window_notes)
if notes:
chord = detector.detect_chord_from_midi_notes(notes[0])
if chord:
chords.append(chord)

# Load audio for MFCC extraction
y, sr = librosa.load(str(separated_audio_path), sr=sample_rate)

# Use a single connection for all inserts
with db_cursor() as (conn, cursor):
for window_start, window_notes in time_windows:
notes = analyzer.group_simultaneous_notes(window_notes)
if notes:
chord = detector.detect_chord_from_midi_notes(notes[0])
if chord:
chords.append(chord)

# Extract MFCC and pHash
try:
# Extract audio segment (window_size=1.0)
start_sample = int(window_start * sr)
end_sample = int((window_start + 1.0) * sr)

if start_sample < len(y):
segment = y[start_sample:min(end_sample, len(y))]

if len(segment) > 0:
# Compute MFCC
mfcc = librosa.feature.mfcc(y=segment, sr=sr, n_mfcc=20)

# Normalize MFCC to 0-255 for image conversion
mfcc_norm = (mfcc - mfcc.min()) / (mfcc.max() - mfcc.min() + 1e-6) * 255
mfcc_img = Image.fromarray(mfcc_norm.astype(np.uint8))

# Compute pHash
phash = str(imagehash.phash(mfcc_img))

# Save to DB
save_phash_chord_pair(phash, str(chord), cursor=cursor)
except Exception as e:
print(f"Warning: Failed to generate pHash for chord {chord}: {e}")

key_detector = KeyDetector()
all_notes = all_notes_flat
Expand All @@ -365,8 +437,16 @@ def e2e_base_ready_task(self, audio_file_path: str, instrument: str):
'key': str(key) if key else None,
'chords': [{'symbol': str(chord), 'duration': 1.0} for chord in chords]
}
json_str = json.dumps(progression_data, indent=2)
with open(chord_output_path, 'w') as f:
json.dump(progression_data, f, indent=2)
f.write(json_str)

# Save to cache if we have a file pHash
if file_phash:
try:
save_cached_progression(file_phash, json_str)
except Exception as e:
print(f"Warning: Failed to save to cache: {e}")

chord_progression_url = f'/outputs/{job_id}/chord_progression.json'

Expand Down
2 changes: 2 additions & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -42,3 +42,5 @@ jams
python-dotenv
deprecated
onnx>=1.19.0
ImageHash
Pillow