Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 31 additions & 13 deletions not_tcp/host.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ def from_header(cls, header: Header, body: bytes) -> "Packet":
)

def __len__(self):
return len(Header) + len(self.body)
return Header.length() + len(self.body)

def header(self) -> Header:
assert self.stream_id >= 0
Expand All @@ -92,6 +92,11 @@ def to_bytes(self) -> bytes:

@classmethod
def from_bytes(cls, buf: bytes) -> (Optional["Packet"], bytes):
# Trim null bytes that prefix a packet.
# This is (retroactively) why we start with stream ID 1!
while len(buf) > 0 and buf[0] == 0:
buf = buf[1:]

if len(buf) < Header.length():
return None, buf
header = Header.from_bytes(buf[:Header.length()])
Expand Down Expand Up @@ -171,17 +176,30 @@ async def run_outbound(self, number: int, writer: StreamWriter):
consumed = (buffer_len - len(buffer))
total_bytes += consumed
if p is None:
continue
buffer = rem
if packet_count == 0:
assert p.start
packet_count += 1
if not p.to_host:
# Ignore the packet
continue
writer.write(p.body)
await writer.drain()
if p.end:
break
if consumed > 0:
olog.debug(
f"consumed {consumed} padding zeros")

# No packet to consume. Get more data.
rcvd = self.recv() # Has its own timeout, but isn't async. So:
await asyncio.sleep(0)
buffer += rcvd
else:
olog.debug(f"consumed {consumed} bytes")

if packet_count == 0:
assert p.start
packet_count += 1
if not p.to_host:
# Ignore the packet
continue
writer.write(p.body)
await writer.drain()
if p.end:
break
olog.info(
f"device closed outbound connection for client {number}")

# TODO: Handle multiple streams.
writer.close()
await writer.wait_closed()
64 changes: 49 additions & 15 deletions not_tcp/not_tcp.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
import session
from stream_utils import LimitForwarder

BUFFER_SIZE = 256


class Flags(Struct):
"""
Expand Down Expand Up @@ -119,7 +121,7 @@ def elaborate(self, platform):
m = Module()
connected = self.connected
input_buffer = m.submodules.input_buffer = SyncFIFOBuffered(
width=8, depth=256)
width=8, depth=BUFFER_SIZE)
m.d.comb += [
self.stop.data.payload.eq(input_buffer.r_stream.payload),
self.stop.data.valid.eq(input_buffer.r_stream.valid),
Expand Down Expand Up @@ -236,9 +238,8 @@ def __init__(self, stream_id):
def elaborate(self, platform):
m = Module()

# Each of these is big enough to buffer one full packet.
output_buffer = m.submodules.output_buffer = SyncFIFOBuffered(
width=8, depth=256)
width=8, depth=BUFFER_SIZE)

m.d.comb += [
output_buffer.w_stream.payload.eq(self.stop.data.payload),
Expand All @@ -254,23 +255,33 @@ def elaborate(self, platform):
output_limiter.start.eq(0),
self.bus.valid.eq(0)
]
connect(m, output_buffer.r_stream, output_limiter.inbound)

m.d.comb += [
self.bus.payload.eq(output_limiter.outbound.payload),
self.bus.valid.eq(output_limiter.outbound.valid),
output_limiter.outbound.ready.eq(self.bus.ready),
]
flags_layout = UnionLayout({"bytes": unsigned(8), "flags": Flags})

# Flags for outbound packet:
send_flags = Signal(flags_layout)
m.d.sync += send_flags.flags.to_host.eq(1)
send_len = Signal(8)
# Pad up to 64 bytes of zeros, to ensure packet delivery.
pad_len = Signal(8)

# Invariants:
# - End is set iff ~active and the buffer is empty.
# - We enter disconnected iff End is clear,
# i.e. End has been sent.

# Cases in which we want to send a packet:
with m.FSM(name="write"):
with m.State("disconnected"):
m.d.comb += Assert(~send_flags.flags.end)
m.next = "disconnected"
with m.If(self.stop.active):
# Immediately send a "start" packet.
m.d.sync += send_flags.flags.start.eq(1)
m.d.sync += send_flags.flags.end.eq(0)
m.d.sync += self.connected.eq(1)
m.next = "write-stream"

Expand All @@ -287,17 +298,27 @@ def elaborate(self, platform):
# Lock in the level as the length of this packet.
# We may send a short (zero-length) packet
# to start or end the connection.
m.d.sync += send_len.eq(output_buffer.r_level)
m.d.sync += send_len.eq(output_buffer.level)
# We send an explicit empty END packet.
m.d.sync += send_flags.flags.end.eq(
~self.stop.active &
(output_buffer.r_level == Const(0)))
(output_buffer.level == Const(0)))
with m.If(self.bus.ready):
m.next = "write-len"
with m.State("write-len"):
m.next = "write-len"
m.d.comb += self.bus.payload.eq(send_len)
m.d.comb += self.bus.valid.eq(1)

# Precompute padding on a cycle where we don't otherwise
# have much to do.
# In theory we only need to pad to 64...
# but that still doesn't get us the stop byte...
# so, double-padding?
m.d.sync += [
pad_len.eq(128 - ((3 + send_len) % 64))
]

with m.If(self.bus.ready):
m.next = "write-flags"
with m.State("write-flags"):
Expand All @@ -314,22 +335,35 @@ def elaborate(self, platform):
m.next = "write-body"
with m.State("write-body"):
m.next = "write-body"
connect(m, output_buffer.r_stream, output_limiter.inbound)

with m.If(output_limiter.done):
m.next = "zero-pad"
m.d.comb += [
output_limiter.count.eq(pad_len),
output_limiter.start.eq(1),
output_limiter.inbound.payload.eq(0),
output_limiter.inbound.valid.eq(1),
]

with m.State("zero-pad"):
m.next = "zero-pad"
m.d.comb += [
self.bus.payload.eq(output_limiter.outbound.payload),
self.bus.valid.eq(output_limiter.outbound.valid),
output_limiter.outbound.ready.eq(self.bus.ready),
output_limiter.inbound.payload.eq(0),
output_limiter.inbound.valid.eq(1),
]

with m.If(output_limiter.done):
m.d.sync += [
send_flags.flags.start.eq(0),
send_flags.flags.end.eq(0),
]
with m.If(send_flags.flags.end):
m.d.sync += self.connected.eq(0)
m.next = "disconnected"
with m.Else():
m.next = "write-stream"
# In either branch, we've sent a start packet.
m.d.sync += [
send_flags.flags.start.eq(0),
send_flags.flags.end.eq(0),
]

return m

Expand Down
5 changes: 3 additions & 2 deletions not_tcp/not_tcp_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,8 +78,9 @@ async def driver(ctx):
while len(rcvd) > 0:
# All data should be packetized.
(p, remainder) = Packet.from_bytes(rcvd)
assert p is not None, f"remaining data: {rcvd}"
packets += [p]
assert (p is not None) or (len(remainder) == 0)
if p is not None:
packets += [p]
rcvd = remainder
bodies = bytes()
for i in range(len(packets)):
Expand Down
2 changes: 2 additions & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ git+https://github.com/amaranth-lang/amaranth-soc@5c43cf58f15d9cd9c69ff83c979977
git+https://github.com/greatscottgadgets/luna@0.2.0

regex
pyserial

# Dev dependencies:
flake8
pytest
Expand Down
74 changes: 74 additions & 0 deletions serial_server.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
import asyncio
import logging
import serial
import traceback

from not_tcp.host import StreamProxy


log = logging.getLogger(__name__)


class HostSerial(StreamProxy):
"""
The real deal: serving to the Fomu.
"""

def __init__(self, path):
"""
Arguments
--------
path: path to the serial device
"""
self._path = path
self._conn = None

def __enter__(self):
assert self._conn is None, "HostSerial is not reentrant"
self._conn = serial.Serial(
self._path, baudrate=9600, timeout=1, inter_byte_timeout=1)
assert self._conn.is_open, "HostSerial failed to open device"

return self

def __exit__(self, exc_type, exc_value, exe_traceback):
if exe_traceback is not None:
traceback.print_tb(exe_traceback)

self._conn.close()
self._conn = None

def send(self, b: bytes):
if self._conn is None or not self._conn.is_open:
log.error("HostSerial is not initialized")
return
self._conn.write(b)
self._conn.flush()
hex = b.hex(sep=' ')
log.debug(f"Wrote {len(b)} bytes to serial: ", hex)

def recv(self) -> bytes:
if self._conn is None or not self._conn.is_open:
log.error("HostSerial is not initialized")
return bytes()

# USB CDC is 64B, and we pad.
v = self._conn.read(64)
if v is None:
return bytes()
return v


async def amain(port):
# TODO: Scan for devices, set up or reset.
with HostSerial("/dev/ttyACM0") as srv:
server = await asyncio.start_server(
client_connected_cb=srv.client_connected,
host="localhost",
port=port
)
log.info(f"listening on port {port}\n")
await server.serve_forever()

if __name__ == "__main__":
asyncio.run(amain(3278))
8 changes: 6 additions & 2 deletions stream_fixtures.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,17 +81,21 @@ def collect_queue(self, q: queue.Queue[bytes], batch_size: int = 100,
stream = self._stream

async def collector(ctx):
ctx.set(stream.ready, 1)
ready = self.is_ready()
ctx.set(stream.ready, ready)
countup = 0
batch = bytes()

async for clk_edge, rst_value, valid, payload in ctx.tick().sample(
stream.valid, stream.payload):
if rst_value or (not clk_edge):
continue
if valid == 1:
if ready == 1 and valid == 1:
# We just transferred a payload byte.
batch += bytes([payload])
ready = self.is_ready()
else:
ready = ready | self.is_ready()
countup += 1

batch_exceeded = len(batch) >= batch_size
Expand Down