Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
131 changes: 131 additions & 0 deletions api/data_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -386,3 +386,134 @@ def from_neo(cls, neo_blocks, data_file_ur):
]
}
}

# ============================================================
# API v2 models - Simplified structure endpoint
# ============================================================

class SignalInfo(BaseModel):
    """Metadata describing one analog (or irregularly-sampled) signal.

    Carries naming, units, timing and shape information only; the sample
    data itself is served separately by the /data/analogsignal/ endpoint.
    """
    signal_id: int
    name: str
    units: str
    sampling_rate: float | None = None
    sampling_rate_units: str | None = None
    t_start: float
    t_stop: float
    duration: float
    n_channels: int
    n_samples: int
    is_irregular: bool = False

    @classmethod
    def from_neo(cls, signal, signal_id, irregular=False):
        """Build a SignalInfo from a Neo signal object, loading it if lazy."""
        if isinstance(signal, proxyobjects.BaseProxy):
            # Proxy objects defer I/O; load to expose timing/shape metadata.
            signal = signal.load()
        shape = signal.shape
        fields = dict(
            signal_id=signal_id,
            name=signal.name or f"Signal {signal_id}",
            units=str(signal.units.dimensionality),
            t_start=float(signal.t_start.magnitude),
            t_stop=float(signal.t_stop.magnitude),
            duration=float((signal.t_stop - signal.t_start).magnitude),
            # Neo signals are 2-D (samples, channels); treat a 1-D array
            # as a single channel.
            n_channels=shape[1] if len(shape) > 1 else 1,
            n_samples=shape[0],
            is_irregular=irregular,
        )
        if not irregular:
            # Only regularly-sampled signals carry a sampling rate.
            fields["sampling_rate"] = float(signal.sampling_rate.magnitude)
            fields["sampling_rate_units"] = str(signal.sampling_rate.units.dimensionality)
        return cls(**fields)


class SpikeTrainInfo(BaseModel):
    """Metadata about a spike train, without the actual spike times."""
    train_id: int
    name: str
    units: str
    t_stop: float
    count: int

    @classmethod
    def from_neo(cls, spike_train, train_id):
        """Build a SpikeTrainInfo from a Neo SpikeTrain, loading it if lazy."""
        if isinstance(spike_train, proxyobjects.BaseProxy):
            # Proxy objects defer I/O; load to access times and metadata.
            spike_train = spike_train.load()
        label = spike_train.name or f"Unit {train_id}"
        return cls(
            train_id=train_id,
            name=label,
            units=str(spike_train.units.dimensionality),
            t_stop=float(spike_train.t_stop.magnitude),
            count=len(spike_train.times),
        )


class SegmentStructure(BaseModel):
    """Structure of one segment: metadata for every contained signal and spike train."""
    segment_id: int
    name: str
    description: str
    rec_datetime: datetime | None = None
    analog_signals: list[SignalInfo]
    irregular_signals: list[SignalInfo]
    spike_trains: list[SpikeTrainInfo]

    @classmethod
    def from_neo(cls, neo_segment, segment_id):
        """Build a SegmentStructure from a Neo Segment."""
        regular = [
            SignalInfo.from_neo(sig, index)
            for index, sig in enumerate(neo_segment.analogsignals)
        ]
        irregular = [
            SignalInfo.from_neo(sig, index, irregular=True)
            for index, sig in enumerate(neo_segment.irregularlysampledsignals)
        ]
        trains = [
            SpikeTrainInfo.from_neo(train, index)
            for index, train in enumerate(neo_segment.spiketrains)
        ]
        return cls(
            segment_id=segment_id,
            name=neo_segment.name or f"Segment {segment_id}",
            description=neo_segment.description or "",
            rec_datetime=parse_datetime(neo_segment.rec_datetime),
            analog_signals=regular,
            irregular_signals=irregular,
            spike_trains=trains,
        )


class BlockStructure(BaseModel):
    """Structure of one block: its metadata plus all segment structures."""
    block_id: int
    name: str
    description: str
    rec_datetime: datetime | None = None
    annotations: dict[str, str]
    segments: list[SegmentStructure]

    @classmethod
    def from_neo(cls, neo_block, block_id):
        """Build a BlockStructure from a Neo Block."""
        segments = [
            SegmentStructure.from_neo(segment, index)
            for index, segment in enumerate(neo_block.segments)
        ]
        return cls(
            block_id=block_id,
            name=neo_block.name or f"Block {block_id}",
            description=neo_block.description or "",
            rec_datetime=parse_datetime(neo_block.rec_datetime),
            annotations=sanitise_annotations(neo_block.annotations),
            segments=segments,
        )


class FileStructure(BaseModel):
    """Complete structure of a data file: its URL and every block within it."""
    url: str
    blocks: list[BlockStructure]

    @classmethod
    def from_neo(cls, neo_blocks, url):
        """Build a FileStructure from the sequence of Neo Blocks read from *url*."""
        block_structures = [
            BlockStructure.from_neo(block, index)
            for index, block in enumerate(neo_blocks)
        ]
        return cls(url=str(url), blocks=block_structures)
4 changes: 3 additions & 1 deletion api/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,14 @@
from starlette.exceptions import HTTPException as StarletteHTTPException

from .resources.v1 import router as router_v1
from .resources.v2 import router as router_v2
from .metadata import title, description


app = FastAPI(
title=title,
description=description,
version="1.8",
version="2.0",
openapi_url="/api/openapi.json",
docs_url="/api/docs",
redoc_url="/api/redoc"
Expand Down Expand Up @@ -63,3 +64,4 @@ async def validation_exception_handler(request: Request, exc: RequestValidationE

app.include_router(router_v1, prefix="/api/v1")
app.include_router(router_v1, prefix="/api")
app.include_router(router_v2, prefix="/api/v2")
209 changes: 209 additions & 0 deletions api/resources/v2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,209 @@
"""
Implementation of endpoints, API version 2.

Simplified API that reduces the number of sequential calls needed
to discover and access data in a file.

Copyright CNRS 2023-2026
Authors: Andrew P. Davison, Thierry Djebouri
Licence: MIT (see LICENSE)
"""

from typing import Annotated
from pydantic import HttpUrl, PositiveInt

from fastapi import Query, HTTPException, APIRouter, status

from ..metadata import title, description
from ..data_models import (
IOModule,
AnalogSignal,
SpikeTrain,
FileStructure,
)
from ..data_handler import load_blocks

router = APIRouter()


@router.get("/")
async def info():
"""Return information about the API."""
return {
"title": title,
"description": description.strip(),
"version": 2.0,
}


@router.get("/structure/")
async def get_structure(
url: Annotated[
HttpUrl, Query(description="Location of a data file that can be read by Neo.")
],
type: Annotated[
IOModule,
Query(
description=(
"Specify a specific Neo IO module that should be used to open the data file. "
"If not provided, Neo will try to determine which module to use."
)
),
] = None,
refresh_cache: Annotated[
bool,
Query(
description=(
"If true, any previously cached version of the file will be "
"invalidated and the file will be re-downloaded from the source."
)
),
] = False,
) -> FileStructure:
"""
Return the complete structure of a data file in a single call.

This includes all blocks, segments, and metadata about the signals
and spike trains in each segment (but not the actual data).

Replaces the v1 /blockdata/ and /segmentdata/ endpoints.
"""
blocks = load_blocks(str(url), type, refresh_cache=refresh_cache)
return FileStructure.from_neo(blocks, url)


@router.get("/data/analogsignal/")
async def get_analogsignal_data(
url: Annotated[
HttpUrl, Query(description="Location of a data file that can be read by Neo.")
],
segment_id: Annotated[
int,
Query(description="Index of the segment containing the signal."),
],
signal_id: Annotated[
int, Query(description="Index of the signal within the segment.")
],
block_id: Annotated[
int,
Query(description="Index of the block containing the segment."),
] = 0,
type: Annotated[
IOModule,
Query(
description=(
"Specify a specific Neo IO module that should be used to open the data file. "
"If not provided, Neo will try to determine which module to use."
)
),
] = None,
down_sample_factor: Annotated[
PositiveInt | None | str,
Query(
description=(
"Factor by which data should be downsampled prior to loading. "
"Useful for faster loading of large files."
)
),
] = 1,
refresh_cache: Annotated[
bool,
Query(
description=(
"If true, any previously cached version of the file will be "
"invalidated and the file will be re-downloaded from the source."
)
),
] = False,
) -> AnalogSignal:
"""
Get an analog signal including both data and metadata.

Use the /structure/ endpoint first to discover available signals.
"""
try:
block = load_blocks(str(url), type, refresh_cache=refresh_cache)[block_id]
except IndexError:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"block_id {block_id} is out of range.",
)
try:
segment = block.segments[segment_id]
except IndexError:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"segment_id {segment_id} is out of range.",
)
if len(segment.analogsignals) > 0:
container = segment.analogsignals
else:
container = segment.irregularlysampledsignals
try:
signal = container[signal_id]
except IndexError:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"signal_id {signal_id} is out of range.",
)
try:
return AnalogSignal.from_neo(signal, down_sample_factor)
except (ValueError, OSError) as err:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=str(err),
)


@router.get("/data/spiketrains/")
async def get_spiketrain_data(
url: Annotated[
HttpUrl, Query(description="Location of a data file that can be read by Neo.")
],
segment_id: Annotated[
int,
Query(description="Index of the segment containing the spike trains."),
],
block_id: Annotated[
int,
Query(description="Index of the block containing the segment."),
] = 0,
type: Annotated[
IOModule,
Query(
description=(
"Specify a specific Neo IO module that should be used to open the data file. "
"If not provided, Neo will try to determine which module to use."
)
),
] = None,
refresh_cache: Annotated[
bool,
Query(
description=(
"If true, any previously cached version of the file will be "
"invalidated and the file will be re-downloaded from the source."
)
),
] = False,
) -> dict[str, SpikeTrain]:
"""
Get all spike trains from a given segment.

Use the /structure/ endpoint first to discover available spike trains.
"""
try:
block = load_blocks(str(url), type, refresh_cache=refresh_cache)[block_id]
except IndexError:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"block_id {block_id} is out of range.",
)
try:
segment = block.segments[segment_id]
except IndexError:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"segment_id {segment_id} is out of range.",
)
return {str(i): SpikeTrain.from_neo(st) for i, st in enumerate(segment.spiketrains)}