diff --git a/not_tcp/host.py b/not_tcp/host.py index 2c47667..4f6b2a6 100644 --- a/not_tcp/host.py +++ b/not_tcp/host.py @@ -79,7 +79,7 @@ def from_header(cls, header: Header, body: bytes) -> "Packet": ) def __len__(self): - return len(Header) + len(self.body) + return Header.length() + len(self.body) def header(self) -> Header: assert self.stream_id >= 0 @@ -92,6 +92,11 @@ def to_bytes(self) -> bytes: @classmethod def from_bytes(cls, buf: bytes) -> (Optional["Packet"], bytes): + # Trim null bytes that prefix a packet. + # This is (retroactively) why we start with stream ID 1! + while len(buf) > 0 and buf[0] == 0: + buf = buf[1:] + if len(buf) < Header.length(): return None, buf header = Header.from_bytes(buf[:Header.length()]) @@ -171,17 +176,30 @@ async def run_outbound(self, number: int, writer: StreamWriter): consumed = (buffer_len - len(buffer)) total_bytes += consumed if p is None: - continue - buffer = rem - if packet_count == 0: - assert p.start - packet_count += 1 - if not p.to_host: - # Ignore the packet - continue - writer.write(p.body) - await writer.drain() - if p.end: - break + if consumed > 0: + olog.debug( + f"consumed {consumed} padding zeros") + + # No packet to consume. Get more data. + rcvd = self.recv() # Has its own timeout, but isn't async. So: + await asyncio.sleep(0) + buffer += rcvd + else: + olog.debug(f"consumed {consumed} bytes") + + if packet_count == 0: + assert p.start + packet_count += 1 + if not p.to_host: + # Ignore the packet + continue + writer.write(p.body) + await writer.drain() + if p.end: + break + olog.info( + f"device closed outbound connection for client {number}") + + # TODO: Handle multiple streams. writer.close() await writer.wait_closed() diff --git a/not_tcp/not_tcp.py b/not_tcp/not_tcp.py index 15c032e..8a1cf27 100644 --- a/not_tcp/not_tcp.py +++ b/not_tcp/not_tcp.py @@ -15,6 +15,8 @@ import session from stream_utils import LimitForwarder +BUFFER_SIZE = 256 + class Flags(Struct): """ @@ -119,7 +121,7 @@ def elaborate(self, platform): m = Module() connected = self.connected input_buffer = m.submodules.input_buffer = SyncFIFOBuffered( - width=8, depth=256) + width=8, depth=BUFFER_SIZE) m.d.comb += [ self.stop.data.payload.eq(input_buffer.r_stream.payload), self.stop.data.valid.eq(input_buffer.r_stream.valid), @@ -236,9 +238,8 @@ def __init__(self, stream_id): def elaborate(self, platform): m = Module() - # Each of these is big enough to buffer one full packet. output_buffer = m.submodules.output_buffer = SyncFIFOBuffered( - width=8, depth=256) + width=8, depth=BUFFER_SIZE) m.d.comb += [ output_buffer.w_stream.payload.eq(self.stop.data.payload), @@ -254,23 +255,33 @@ def elaborate(self, platform): output_limiter.start.eq(0), self.bus.valid.eq(0) ] - connect(m, output_buffer.r_stream, output_limiter.inbound) - + m.d.comb += [ + self.bus.payload.eq(output_limiter.outbound.payload), + self.bus.valid.eq(output_limiter.outbound.valid), + output_limiter.outbound.ready.eq(self.bus.ready), + ] flags_layout = UnionLayout({"bytes": unsigned(8), "flags": Flags}) # Flags for outbound packet: send_flags = Signal(flags_layout) m.d.sync += send_flags.flags.to_host.eq(1) send_len = Signal(8) + # Pad up to 64 bytes of zeros, to ensure packet delivery. + pad_len = Signal(8) + + # Invariants: + # - End is set iff ~active and the buffer is empty. + # - We enter disconnected iff End is clear, + # i.e. End has been sent. # Cases in which we want to send a packet: with m.FSM(name="write"): with m.State("disconnected"): + m.d.comb += Assert(~send_flags.flags.end) m.next = "disconnected" with m.If(self.stop.active): # Immediately send a "start" packet. m.d.sync += send_flags.flags.start.eq(1) - m.d.sync += send_flags.flags.end.eq(0) m.d.sync += self.connected.eq(1) m.next = "write-stream" @@ -287,17 +298,27 @@ def elaborate(self, platform): # Lock in the level as the length of this packet. # We may send a short (zero-length) packet # to start or end the connection. - m.d.sync += send_len.eq(output_buffer.r_level) + m.d.sync += send_len.eq(output_buffer.level) # We send an explicit empty END packet. m.d.sync += send_flags.flags.end.eq( ~self.stop.active & - (output_buffer.r_level == Const(0))) + (output_buffer.level == Const(0))) with m.If(self.bus.ready): m.next = "write-len" with m.State("write-len"): m.next = "write-len" m.d.comb += self.bus.payload.eq(send_len) m.d.comb += self.bus.valid.eq(1) + + # Precompute padding on a cycle where we don't otherwise + # have much to do. + # In theory we only need to pad to 64... + # but that still doesn't get us the stop byte... + # so, double-padding? + m.d.sync += [ + pad_len.eq(128 - ((3 + send_len) % 64)) + ] + with m.If(self.bus.ready): m.next = "write-flags" with m.State("write-flags"): @@ -314,22 +335,35 @@ def elaborate(self, platform): m.next = "write-body" with m.State("write-body"): m.next = "write-body" + connect(m, output_buffer.r_stream, output_limiter.inbound) + + with m.If(output_limiter.done): + m.next = "zero-pad" + m.d.comb += [ + output_limiter.count.eq(pad_len), + output_limiter.start.eq(1), + output_limiter.inbound.payload.eq(0), + output_limiter.inbound.valid.eq(1), + ] + + with m.State("zero-pad"): + m.next = "zero-pad" m.d.comb += [ - self.bus.payload.eq(output_limiter.outbound.payload), - self.bus.valid.eq(output_limiter.outbound.valid), - output_limiter.outbound.ready.eq(self.bus.ready), + output_limiter.inbound.payload.eq(0), + output_limiter.inbound.valid.eq(1), ] with m.If(output_limiter.done): - m.d.sync += [ - send_flags.flags.start.eq(0), - send_flags.flags.end.eq(0), - ] with m.If(send_flags.flags.end): m.d.sync += self.connected.eq(0) m.next = "disconnected" with m.Else(): m.next = "write-stream" + # In either branch, we've sent a start packet. + m.d.sync += [ + send_flags.flags.start.eq(0), + send_flags.flags.end.eq(0), + ] return m diff --git a/not_tcp/not_tcp_test.py b/not_tcp/not_tcp_test.py index 5bfa88c..f57b6be 100644 --- a/not_tcp/not_tcp_test.py +++ b/not_tcp/not_tcp_test.py @@ -78,8 +78,9 @@ async def driver(ctx): while len(rcvd) > 0: # All data should be packetized. (p, remainder) = Packet.from_bytes(rcvd) - assert p is not None, f"remaining data: {rcvd}" - packets += [p] + assert (p is not None) or (len(remainder) == 0) + if p is not None: + packets += [p] rcvd = remainder bodies = bytes() for i in range(len(packets)): diff --git a/requirements.txt b/requirements.txt index 2177abc..8152297 100644 --- a/requirements.txt +++ b/requirements.txt @@ -14,6 +14,8 @@ git+https://github.com/amaranth-lang/amaranth-soc@5c43cf58f15d9cd9c69ff83c979977 git+https://github.com/greatscottgadgets/luna@0.2.0 regex +pyserial + # Dev dependencies: flake8 pytest diff --git a/serial_server.py b/serial_server.py new file mode 100644 index 0000000..26f20b4 --- /dev/null +++ b/serial_server.py @@ -0,0 +1,74 @@ +import asyncio +import logging +import serial +import traceback + +from not_tcp.host import StreamProxy + + +log = logging.getLogger(__name__) + + +class HostSerial(StreamProxy): + """ + The real deal: serving to the Fomu. + """ + + def __init__(self, path): + """ + Arguments + -------- + path: path to the serial device + """ + self._path = path + self._conn = None + + def __enter__(self): + assert self._conn is None, "HostSerial is not reentrant" + self._conn = serial.Serial( + self._path, baudrate=9600, timeout=1, inter_byte_timeout=1) + assert self._conn.is_open, "HostSerial failed to open device" + + return self + + def __exit__(self, exc_type, exc_value, exe_traceback): + if exe_traceback is not None: + traceback.print_tb(exe_traceback) + + self._conn.close() + self._conn = None + + def send(self, b: bytes): + if self._conn is None or not self._conn.is_open: + log.error("HostSerial is not initialized") + return + self._conn.write(b) + self._conn.flush() + hex = b.hex(sep=' ') + log.debug(f"Wrote {len(b)} bytes to serial: ", hex) + + def recv(self) -> bytes: + if self._conn is None or not self._conn.is_open: + log.error("HostSerial is not initialized") + return bytes() + + # USB CDC is 64B, and we pad. + v = self._conn.read(64) + if v is None: + return bytes() + return v + + +async def amain(port): + # TODO: Scan for devices, set up or reset. + with HostSerial("/dev/ttyACM0") as srv: + server = await asyncio.start_server( + client_connected_cb=srv.client_connected, + host="localhost", + port=port + ) + log.info(f"listening on port {port}\n") + await server.serve_forever() + +if __name__ == "__main__": + asyncio.run(amain(3278)) diff --git a/stream_fixtures.py b/stream_fixtures.py index 9f0573f..6a3019e 100644 --- a/stream_fixtures.py +++ b/stream_fixtures.py @@ -81,7 +81,8 @@ def collect_queue(self, q: queue.Queue[bytes], batch_size: int = 100, stream = self._stream async def collector(ctx): - ctx.set(stream.ready, 1) + ready = self.is_ready() + ctx.set(stream.ready, ready) countup = 0 batch = bytes() @@ -89,9 +90,12 @@ async def collector(ctx): stream.valid, stream.payload): if rst_value or (not clk_edge): continue - if valid == 1: + if ready == 1 and valid == 1: # We just transferred a payload byte. batch += bytes([payload]) + ready = self.is_ready() + else: + ready = ready | self.is_ready() countup += 1 batch_exceeded = len(batch) >= batch_size