diff --git a/cmd/teslausb/main.go b/cmd/teslausb/main.go index 34412ff..dec7b34 100644 --- a/cmd/teslausb/main.go +++ b/cmd/teslausb/main.go @@ -14,6 +14,7 @@ import ( "github.com/teslausb-go/teslausb/internal/state" "github.com/teslausb-go/teslausb/internal/system" "github.com/teslausb-go/teslausb/internal/web" + "github.com/teslausb-go/teslausb/internal/wire" ) var version = "dev" @@ -72,7 +73,7 @@ func main() { go monitor.RunWiFiMonitor(ctx) // Create state machine - machine := state.New() + machine := state.New(wire.NewDeps()) // Start web server srv := web.NewServer(machine, version, *configPath) diff --git a/internal/archive/archive.go b/internal/archive/archive.go index 1b72fef..913605f 100644 --- a/internal/archive/archive.go +++ b/internal/archive/archive.go @@ -16,6 +16,8 @@ import ( const ArchiveMount = "/mnt/archive" +const tcpDialTimeout = 5 * time.Second + // IsReachable checks if the archive server is reachable via TCP. func IsReachable() bool { cfg := config.Get() @@ -32,7 +34,7 @@ func tcpReachable(host, port string) bool { if host == "" { return false } - conn, err := net.DialTimeout("tcp", host+":"+port, 5*time.Second) + conn, err := net.DialTimeout("tcp", host+":"+port, tcpDialTimeout) if err != nil { return false } diff --git a/internal/config/config.go b/internal/config/config.go index 0f292b2..87c4a88 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -79,10 +79,17 @@ func Load(path string) (*Config, error) { return &cfg, nil } +// Get returns a shallow copy of the current config. +// Safe because Config contains only value-type fields. +// If slice/map fields are added, switch to a deep copy. func Get() *Config { mu.RLock() defer mu.RUnlock() - return current + if current == nil { + return nil + } + cp := *current + return &cp } func Save(path string, cfg *Config) error { diff --git a/internal/config/config_test.go b/internal/config/config_test.go index 9f058cf..55ba8bd 100644 --- a/internal/config/config_test.go +++ b/internal/config/config_test.go @@ -65,3 +65,24 @@ func TestSaveAndReload(t *testing.T) { t.Errorf("expected 10.0.0.1, got %s", loaded.NFS.Server) } } + +func TestGetReturnsCopy(t *testing.T) { + // Load a config so current is set + tmpFile := filepath.Join(t.TempDir(), "test.yaml") + os.WriteFile(tmpFile, []byte("nfs:\n server: original\n"), 0644) + Load(tmpFile) + + cfg1 := Get() + if cfg1 == nil { + t.Fatal("expected non-nil config") + } + cfg1.NFS.Server = "mutated" + + cfg2 := Get() + if cfg2.NFS.Server == "mutated" { + t.Error("Get() should return a copy; mutation leaked through") + } + if cfg2.NFS.Server != "original" { + t.Errorf("expected 'original', got %q", cfg2.NFS.Server) + } +} diff --git a/internal/disk/disk.go b/internal/disk/disk.go index 8f56ca8..19dadc4 100644 --- a/internal/disk/disk.go +++ b/internal/disk/disk.go @@ -16,6 +16,12 @@ const ( MountPoint = "/mnt/cam" ) +const ( + diskReserveBytes = 500 * 1024 * 1024 // 500MB headroom for backing partition + minDiskSize = 1024 * 1024 * 1024 // 1GB minimum cam disk + truncatedClipThreshold = int64(100_000) // clips smaller than this are truncated +) + func Exists() bool { _, err := os.Stat(BackingFile) return err == nil @@ -34,9 +40,9 @@ func Create() error { return fmt.Errorf("statfs: %w", err) } available := int64(stat.Bavail) * int64(stat.Bsize) - reserve := int64(500 * 1024 * 1024) // 500MB headroom + reserve := int64(diskReserveBytes) size := available - reserve - if size < 1024*1024*1024 { // minimum 1GB + if size < minDiskSize { return fmt.Errorf("not enough space: %d bytes available", available) } @@ -157,7 +163,7 @@ func CleanArtifacts() { if err != nil || info.IsDir() { return nil } - if strings.HasSuffix(strings.ToLower(path), ".mp4") && info.Size() < 100_000 { + if strings.HasSuffix(strings.ToLower(path), ".mp4") && info.Size() < truncatedClipThreshold { os.Remove(path) log.Printf("cleaned truncated: %s (%d bytes)", filepath.Base(path), info.Size()) } diff --git a/internal/gadget/idle.go b/internal/gadget/idle.go index fca90ea..d13be09 100644 --- a/internal/gadget/idle.go +++ b/internal/gadget/idle.go @@ -10,6 +10,12 @@ import ( "time" ) +const ( + idleThresholdBytes = int64(500_000) // bytes/sec below which USB is "idle" + idleConsecutiveRequired = 5 // consecutive idle checks needed + idleTimeoutSeconds = 90 // max seconds to wait for idle +) + func findMassStoragePID() (int, error) { entries, _ := os.ReadDir("/proc") for _, e := range entries { @@ -53,10 +59,10 @@ func WaitForIdle() error { prevBytes := int64(-1) idleCount := 0 - threshold := int64(500_000) + threshold := idleThresholdBytes log.Println("waiting for USB write idle...") - for i := 0; i < 90; i++ { + for i := 0; i < idleTimeoutSeconds; i++ { time.Sleep(1 * time.Second) written, err := readWriteBytes(pid) if err != nil { @@ -71,7 +77,7 @@ func WaitForIdle() error { if delta < threshold { idleCount++ - if idleCount >= 5 { + if idleCount >= idleConsecutiveRequired { log.Println("USB write idle detected") return nil } @@ -79,5 +85,5 @@ func WaitForIdle() error { idleCount = 0 } } - return fmt.Errorf("timeout waiting for USB idle after 90 seconds") + return fmt.Errorf("timeout waiting for USB idle after %d seconds", idleTimeoutSeconds) } diff --git a/internal/state/deps.go b/internal/state/deps.go new file mode 100644 index 0000000..bf0aa3a --- /dev/null +++ b/internal/state/deps.go @@ -0,0 +1,59 @@ +package state + +import ( + "context" + + "github.com/teslausb-go/teslausb/internal/webhook" +) + +// DiskOps abstracts disk image operations. +type DiskOps interface { + Exists() bool + Create() error + Mount() error + Unmount() error + CleanArtifacts() + BackingFilePath() string +} + +// GadgetOps abstracts USB gadget lifecycle. +type GadgetOps interface { + Enable(backingFile string) error + Disable() error + WaitForIdle() error +} + +// ArchiveOps abstracts clip archiving operations. +type ArchiveOps interface { + IsReachable() bool + MountArchive() error + UnmountArchive() + ArchiveClips(ctx context.Context) (clips int, bytes int64, err error) + ManageFreeSpace() +} + +// SystemOps abstracts system-level operations (LED, time sync). +type SystemOps interface { + SetLED(mode string) + SyncTime() +} + +// KeepAwaker abstracts vehicle keep-awake signaling. +type KeepAwaker interface { + Send(ctx context.Context, command string) +} + +// Notifier abstracts event notifications. +type Notifier interface { + Send(ctx context.Context, event webhook.Event) +} + +// Deps bundles all external dependencies for the state machine. +type Deps struct { + Disk DiskOps + Gadget GadgetOps + Archive ArchiveOps + System SystemOps + KeepAwake KeepAwaker + Notify Notifier +} diff --git a/internal/state/machine.go b/internal/state/machine.go index abe9916..0cea6d3 100644 --- a/internal/state/machine.go +++ b/internal/state/machine.go @@ -10,13 +10,6 @@ import ( "sync" "time" - "github.com/teslausb-go/teslausb/internal/archive" - "github.com/teslausb-go/teslausb/internal/ble" - "github.com/teslausb-go/teslausb/internal/config" - "github.com/teslausb-go/teslausb/internal/disk" - "github.com/teslausb-go/teslausb/internal/gadget" - "github.com/teslausb-go/teslausb/internal/notify" - "github.com/teslausb-go/teslausb/internal/system" "github.com/teslausb-go/teslausb/internal/webhook" ) @@ -40,6 +33,7 @@ type CumulativeStats struct { type Machine struct { mu sync.RWMutex + deps Deps state State lastArchive time.Time lastError string @@ -53,8 +47,14 @@ type Machine struct { const lastArchiveFile = "/mutable/teslausb/last_archive" const statsFile = "/mutable/teslausb/stats.json" -func New() *Machine { - m := &Machine{state: StateBooting} +const ( + pollInterval = 30 * time.Second + networkStabilizeDelay = 20 * time.Second + keepAliveInterval = 5 * time.Minute +) + +func New(deps Deps) *Machine { + m := &Machine{state: StateBooting, deps: deps} // Restore last archive timestamp if data, err := os.ReadFile(lastArchiveFile); err == nil { if t, err := time.Parse(time.RFC3339, strings.TrimSpace(string(data))); err == nil { @@ -119,18 +119,30 @@ func (m *Machine) setState(s State) { } } +// tryEnableGadget attempts to enable the USB gadget if not already enabled. +// Returns true if the gadget was just enabled (for callers that want to notify). +func (m *Machine) tryEnableGadget() bool { + if m.gadgetEnabled { + return false + } + if err := m.deps.Gadget.Enable(m.deps.Disk.BackingFilePath()); err != nil { + return false + } + m.gadgetEnabled = true + log.Println("USB gadget enabled (delayed)") + return true +} + // Run starts the main state machine loop. func (m *Machine) Run(ctx context.Context) error { - // First-run: create disk image if needed - if !disk.Exists() { + if !m.deps.Disk.Exists() { log.Println("first run: creating cam disk image...") - if err := disk.Create(); err != nil { + if err := m.deps.Disk.Create(); err != nil { return fmt.Errorf("create disk: %w", err) } } - // Enable USB gadget (non-fatal — web UI should work even without UDC) - if err := gadget.Enable(disk.BackingFile); err != nil { + if err := m.deps.Gadget.Enable(m.deps.Disk.BackingFilePath()); err != nil { log.Printf("warning: %v (web UI still available, gadget will retry)", err) m.mu.Lock() m.lastError = err.Error() @@ -140,12 +152,12 @@ func (m *Machine) Run(ctx context.Context) error { } m.setState(StateAway) - system.SetLED("slowblink") + m.deps.System.SetLED("slowblink") for { select { case <-ctx.Done(): - gadget.Disable() + m.deps.Gadget.Disable() return nil default: } @@ -164,7 +176,7 @@ func (m *Machine) Run(ctx context.Context) error { } func (m *Machine) runAway(ctx context.Context) { - ticker := time.NewTicker(30 * time.Second) + ticker := time.NewTicker(pollInterval) defer ticker.Stop() for { @@ -172,14 +184,8 @@ func (m *Machine) runAway(ctx context.Context) { case <-ctx.Done(): return case <-ticker.C: - // Retry gadget enable if it failed (e.g. UDC wasn't available at boot) - if !m.gadgetEnabled { - if err := gadget.Enable(disk.BackingFile); err == nil { - m.gadgetEnabled = true - log.Println("USB gadget enabled (delayed)") - } - } - if archive.IsReachable() { + m.tryEnableGadget() + if m.deps.Archive.IsReachable() { m.setState(StateArriving) return } @@ -188,18 +194,18 @@ func (m *Machine) runAway(ctx context.Context) { } func (m *Machine) runArriving(ctx context.Context) { - system.SetLED("fastblink") + m.deps.System.SetLED("fastblink") - log.Println("archive server reachable, waiting 20s for network to stabilize...") - time.Sleep(20 * time.Second) + log.Println("archive server reachable, waiting for network to stabilize...") + time.Sleep(networkStabilizeDelay) - system.SyncTime() + m.deps.System.SyncTime() - if err := gadget.WaitForIdle(); err != nil { + if err := m.deps.Gadget.WaitForIdle(); err != nil { log.Printf("wait for idle: %v", err) } - if err := gadget.Disable(); err != nil { + if err := m.deps.Gadget.Disable(); err != nil { log.Printf("disable gadget: %v", err) m.gadgetEnabled = false m.setState(StateAway) @@ -207,21 +213,21 @@ func (m *Machine) runArriving(ctx context.Context) { } m.gadgetEnabled = false - notify.Send(ctx, webhook.Event{Event: "usb_disconnected", Message: "USB gadget disabled for archiving"}) + m.deps.Notify.Send(ctx, webhook.Event{Event: "usb_disconnected", Message: "USB gadget disabled for archiving"}) - if err := disk.Mount(); err != nil { + if err := m.deps.Disk.Mount(); err != nil { log.Printf("mount cam: %v", err) - gadget.Enable(disk.BackingFile) + m.deps.Gadget.Enable(m.deps.Disk.BackingFilePath()) m.setState(StateAway) return } - disk.CleanArtifacts() + m.deps.Disk.CleanArtifacts() - if err := archive.MountArchive(); err != nil { + if err := m.deps.Archive.MountArchive(); err != nil { log.Printf("mount archive: %v", err) - disk.Unmount() - gadget.Enable(disk.BackingFile) + m.deps.Disk.Unmount() + m.deps.Gadget.Enable(m.deps.Disk.BackingFilePath()) m.setState(StateAway) return } @@ -229,60 +235,69 @@ func (m *Machine) runArriving(ctx context.Context) { m.setState(StateArchiving) } -func (m *Machine) runArchiving(ctx context.Context) { - cfg := config.Get() - - m.sendKeepAwake(ctx, cfg, "start") - - keepAliveCtx, keepAliveCancel := context.WithCancel(ctx) +// startKeepAlive begins periodic keep-awake nudges and returns a cancel function. +func (m *Machine) startKeepAlive(ctx context.Context) context.CancelFunc { + m.deps.KeepAwake.Send(ctx, "start") + keepAliveCtx, cancel := context.WithCancel(ctx) go func() { - ticker := time.NewTicker(5 * time.Minute) + ticker := time.NewTicker(keepAliveInterval) defer ticker.Stop() for { select { case <-keepAliveCtx.Done(): return case <-ticker.C: - m.sendKeepAwake(keepAliveCtx, cfg, "nudge") + m.deps.KeepAwake.Send(keepAliveCtx, "nudge") } } }() + return cancel +} + +// updateAndPersistStats records archive results in memory and writes to disk. +func (m *Machine) updateAndPersistStats(clips int, bytes int64) { + now := time.Now() + m.mu.Lock() + m.lastArchive = now + m.archiveClips = clips + m.archiveBytes = bytes + m.cumulative.TotalClips += clips + m.cumulative.TotalBytes += bytes + m.cumulative.ArchiveCount++ + m.cumulative.LastArchive = now + cumSnapshot := m.cumulative + m.mu.Unlock() + + os.WriteFile(lastArchiveFile, []byte(now.Format(time.RFC3339)), 0644) + if data, err := json.Marshal(cumSnapshot); err == nil { + if err := os.WriteFile(statsFile, data, 0644); err != nil { + log.Printf("save stats: %v", err) + } + } +} - notify.Send(ctx, webhook.Event{Event: "archive_started", Message: "Archiving dashcam clips"}) +func (m *Machine) runArchiving(ctx context.Context) { + stopKeepAlive := m.startKeepAlive(ctx) + + m.deps.Notify.Send(ctx, webhook.Event{Event: "archive_started", Message: "Archiving dashcam clips"}) start := time.Now() - clips, bytes, err := archive.ArchiveClips(ctx) + clips, bytes, err := m.deps.Archive.ArchiveClips(ctx) duration := time.Since(start) - keepAliveCancel() + stopKeepAlive() if err != nil { m.mu.Lock() m.lastError = err.Error() m.mu.Unlock() log.Printf("archive error: %v", err) - notify.Send(ctx, webhook.Event{ + m.deps.Notify.Send(ctx, webhook.Event{ Event: "archive_error", Message: err.Error(), }) } else { - now := time.Now() - m.mu.Lock() - m.lastArchive = now - m.archiveClips = clips - m.archiveBytes = bytes - m.cumulative.TotalClips += clips - m.cumulative.TotalBytes += bytes - m.cumulative.ArchiveCount++ - m.cumulative.LastArchive = now - cumSnapshot := m.cumulative - m.mu.Unlock() - os.WriteFile(lastArchiveFile, []byte(now.Format(time.RFC3339)), 0644) - if statsData, err := json.Marshal(cumSnapshot); err == nil { - if err := os.WriteFile(statsFile, statsData, 0644); err != nil { - log.Printf("save stats: %v", err) - } - } - notify.Send(ctx, webhook.Event{ + m.updateAndPersistStats(clips, bytes) + m.deps.Notify.Send(ctx, webhook.Event{ Event: "archive_complete", Message: fmt.Sprintf("Archived %d clips in %s", clips, duration.Round(time.Second)), Data: map[string]any{ @@ -293,28 +308,26 @@ func (m *Machine) runArchiving(ctx context.Context) { }) } - archive.ManageFreeSpace() + m.deps.Archive.ManageFreeSpace() m.setState(StateIdle) } func (m *Machine) runIdle(ctx context.Context) { - system.SetLED("heartbeat") - - cfg := config.Get() - m.sendKeepAwake(ctx, cfg, "stop") + m.deps.System.SetLED("heartbeat") + m.deps.KeepAwake.Send(ctx, "stop") - archive.UnmountArchive() - disk.Unmount() + m.deps.Archive.UnmountArchive() + m.deps.Disk.Unmount() - if err := gadget.Enable(disk.BackingFile); err != nil { + if err := m.deps.Gadget.Enable(m.deps.Disk.BackingFilePath()); err != nil { log.Printf("warning: gadget re-enable failed: %v", err) m.gadgetEnabled = false } else { m.gadgetEnabled = true - notify.Send(ctx, webhook.Event{Event: "usb_connected", Message: "USB gadget re-enabled"}) + m.deps.Notify.Send(ctx, webhook.Event{Event: "usb_connected", Message: "USB gadget re-enabled"}) } - ticker := time.NewTicker(30 * time.Second) + ticker := time.NewTicker(pollInterval) defer ticker.Stop() for { @@ -322,41 +335,15 @@ func (m *Machine) runIdle(ctx context.Context) { case <-ctx.Done(): return case <-ticker.C: - // Retry gadget if it failed - if !m.gadgetEnabled { - if err := gadget.Enable(disk.BackingFile); err == nil { - m.gadgetEnabled = true - log.Println("USB gadget enabled (delayed)") - notify.Send(ctx, webhook.Event{Event: "usb_connected", Message: "USB gadget re-enabled"}) - } + if m.tryEnableGadget() { + m.deps.Notify.Send(ctx, webhook.Event{Event: "usb_connected", Message: "USB gadget re-enabled"}) } - if !archive.IsReachable() { + if !m.deps.Archive.IsReachable() { log.Println("archive server unreachable — user left home") m.setState(StateAway) - system.SetLED("slowblink") + m.deps.System.SetLED("slowblink") return } } } } - -func (m *Machine) sendKeepAwake(ctx context.Context, cfg *config.Config, command string) { - if cfg == nil { - return - } - switch cfg.KeepAwake.Method { - case "ble": - if cfg.KeepAwake.VIN != "" { - if command == "stop" { - ble.SentryOff(cfg.KeepAwake.VIN) - } else { - ble.KeepAwake(cfg.KeepAwake.VIN) - } - } - case "webhook": - if cfg.KeepAwake.WebhookURL != "" { - // Send flat {"awake_command":"..."} matching original teslausb format - webhook.SendRaw(ctx, cfg.KeepAwake.WebhookURL, map[string]string{"awake_command": command}) - } - } -} diff --git a/internal/state/machine_test.go b/internal/state/machine_test.go index f3068c2..cc536fe 100644 --- a/internal/state/machine_test.go +++ b/internal/state/machine_test.go @@ -1,18 +1,72 @@ package state import ( + "context" + "sync/atomic" "testing" + "time" + + "github.com/teslausb-go/teslausb/internal/webhook" ) +type noopDisk struct{} + +func (d *noopDisk) Exists() bool { return true } +func (d *noopDisk) Create() error { return nil } +func (d *noopDisk) Mount() error { return nil } +func (d *noopDisk) Unmount() error { return nil } +func (d *noopDisk) CleanArtifacts() {} +func (d *noopDisk) BackingFilePath() string { return "/tmp/test.bin" } + +type noopGadget struct{} + +func (g *noopGadget) Enable(string) error { return nil } +func (g *noopGadget) Disable() error { return nil } +func (g *noopGadget) WaitForIdle() error { return nil } + +type noopArchive struct { + reachable atomic.Bool +} + +func (a *noopArchive) IsReachable() bool { return a.reachable.Load() } +func (a *noopArchive) MountArchive() error { return nil } +func (a *noopArchive) UnmountArchive() {} +func (a *noopArchive) ArchiveClips(context.Context) (int, int64, error) { return 0, 0, nil } +func (a *noopArchive) ManageFreeSpace() {} + +type noopSystem struct{ lastLED string } + +func (s *noopSystem) SetLED(mode string) { s.lastLED = mode } +func (s *noopSystem) SyncTime() {} + +type noopKeepAwake struct{ lastCmd string } + +func (k *noopKeepAwake) Send(_ context.Context, cmd string) { k.lastCmd = cmd } + +type noopNotify struct{ events []webhook.Event } + +func (n *noopNotify) Send(_ context.Context, e webhook.Event) { n.events = append(n.events, e) } + +func testDeps() Deps { + return Deps{ + Disk: &noopDisk{}, + Gadget: &noopGadget{}, + Archive: &noopArchive{}, + System: &noopSystem{}, + KeepAwake: &noopKeepAwake{}, + Notify: &noopNotify{}, + } +} + func TestNewMachine(t *testing.T) { - m := New() + m := New(testDeps()) if m.State() != StateBooting { t.Errorf("expected booting, got %s", m.State()) } } func TestStateTransition(t *testing.T) { - m := New() + m := New(testDeps()) var received State m.OnStateChange(func(s State) { received = s }) m.setState(StateAway) @@ -22,9 +76,112 @@ func TestStateTransition(t *testing.T) { } func TestInfo(t *testing.T) { - m := New() + m := New(testDeps()) info := m.Info() if info["state"] != "booting" { t.Errorf("expected booting, got %s", info["state"]) } } + +func TestTriggerArchive(t *testing.T) { + m := New(testDeps()) + // Not idle — should return false + if m.TriggerArchive() { + t.Error("should not trigger from booting state") + } + // Set to idle + m.setState(StateIdle) + if !m.TriggerArchive() { + t.Error("should trigger from idle state") + } + if m.State() != StateArriving { + t.Errorf("expected arriving, got %s", m.State()) + } +} + +func TestTryEnableGadget(t *testing.T) { + g := &noopGadget{} + deps := testDeps() + deps.Gadget = g + m := New(deps) + + // Already enabled — should return false + m.gadgetEnabled = true + if m.tryEnableGadget() { + t.Error("should return false when already enabled") + } + + // Not enabled — should return true + m.gadgetEnabled = false + if !m.tryEnableGadget() { + t.Error("should return true when enable succeeds") + } + if !m.gadgetEnabled { + t.Error("gadgetEnabled should be true after successful enable") + } +} + +func TestRunAwayTransitionsOnReachable(t *testing.T) { + archive := &noopArchive{} + deps := testDeps() + deps.Archive = archive + m := New(deps) + m.setState(StateAway) + + ctx, cancel := context.WithCancel(context.Background()) + + // Start runAway in goroutine + done := make(chan struct{}) + go func() { + m.runAway(ctx) + close(done) + }() + + // Simulate archive becoming reachable (thread-safe via atomic.Bool) + time.Sleep(100 * time.Millisecond) + archive.reachable.Store(true) + + // Wait for state change (with timeout) + select { + case <-done: + case <-time.After(pollInterval + 5*time.Second): + cancel() + t.Fatal("runAway did not return after archive became reachable") + } + cancel() + + if m.State() != StateArriving { + t.Errorf("expected arriving, got %s", m.State()) + } +} + +func TestRunIdleNotifiesOnGadgetEnable(t *testing.T) { + n := &noopNotify{} + archive := &noopArchive{} + archive.reachable.Store(true) + deps := testDeps() + deps.Notify = n + deps.Archive = archive + m := New(deps) + m.setState(StateIdle) + + ctx, cancel := context.WithCancel(context.Background()) + + go func() { + m.runIdle(ctx) + }() + + // Wait for initial gadget enable + notification + time.Sleep(100 * time.Millisecond) + cancel() + + found := false + for _, e := range n.events { + if e.Event == "usb_connected" { + found = true + } + } + if !found { + t.Error("expected usb_connected notification") + } +} diff --git a/internal/web/mounttest.go b/internal/web/mounttest.go new file mode 100644 index 0000000..687614d --- /dev/null +++ b/internal/web/mounttest.go @@ -0,0 +1,59 @@ +package web + +import ( + "fmt" + "net" + "net/http" + "os" + "os/exec" + "strings" + "time" +) + +const mountTestTimeout = 5 * time.Second + +type mountTest struct { + host string + port string + source string // e.g. "server:/share" or "//server/share" + testDir string // temp mount point + mount func(dir string) ([]byte, error) // execute the mount command + cleanup func() // optional extra cleanup (e.g. remove credentials file) +} + +func (s *Server) runMountTest(w http.ResponseWriter, mt mountTest) { + // TCP connectivity check + conn, err := net.DialTimeout("tcp", mt.host+":"+mt.port, mountTestTimeout) + if err != nil { + jsonResponse(w, map[string]any{ + "ok": false, + "error": fmt.Sprintf("Cannot reach %s:%s — %v", mt.host, mt.port, err), + }) + return + } + conn.Close() + + // Temp mount + os.MkdirAll(mt.testDir, 0755) + defer func() { + exec.Command("umount", "-f", "-l", mt.testDir).Run() + os.Remove(mt.testDir) + if mt.cleanup != nil { + mt.cleanup() + } + }() + + out, err := mt.mount(mt.testDir) + if err != nil { + jsonResponse(w, map[string]any{ + "ok": false, + "error": fmt.Sprintf("Mount failed: %s", strings.TrimSpace(string(out))), + }) + return + } + + jsonResponse(w, map[string]any{ + "ok": true, + "message": fmt.Sprintf("Successfully mounted %s", mt.source), + }) +} diff --git a/internal/web/pathutil.go b/internal/web/pathutil.go new file mode 100644 index 0000000..23165f4 --- /dev/null +++ b/internal/web/pathutil.go @@ -0,0 +1,22 @@ +package web + +import ( + "fmt" + "path/filepath" + "strings" +) + +// safePath validates a user-supplied path against a base directory. +// Returns the full resolved path if it stays within base, or an error +// if the path would escape. A naive join of base + userPath is cleaned +// and checked; if the result would leave base the call is rejected. +func safePath(base, userPath string) (string, error) { + // Clean the raw join first — this preserves ".." traversals so we + // can detect them, unlike the "/" prefix trick which silently strips them. + full := filepath.Clean(filepath.Join(base, userPath)) + rel, err := filepath.Rel(base, full) + if err != nil || rel == ".." || strings.HasPrefix(rel, ".."+string(filepath.Separator)) { + return "", fmt.Errorf("path traversal blocked: %q", userPath) + } + return full, nil +} diff --git a/internal/web/pathutil_test.go b/internal/web/pathutil_test.go new file mode 100644 index 0000000..704f81d --- /dev/null +++ b/internal/web/pathutil_test.go @@ -0,0 +1,44 @@ +package web + +import ( + "testing" +) + +func TestSafePath(t *testing.T) { + base := "/mnt/cam" + + tests := []struct { + name string + input string + want string + wantErr bool + }{ + {"valid subdir", "TeslaCam/SavedClips", "/mnt/cam/TeslaCam/SavedClips", false}, + {"valid file", "TeslaCam/SavedClips/clip.mp4", "/mnt/cam/TeslaCam/SavedClips/clip.mp4", false}, + {"empty path returns base", "", "/mnt/cam", false}, + {"dot returns base", ".", "/mnt/cam", false}, + {"traversal blocked", "../../etc/passwd", "", true}, + {"nested traversal blocked", "TeslaCam/../../../etc/passwd", "", true}, + {"double dot in name OK", "TeslaCam/clip..v2.mp4", "/mnt/cam/TeslaCam/clip..v2.mp4", false}, + {"absolute path confined", "/etc/passwd", "/mnt/cam/etc/passwd", false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := safePath(base, tt.input) + if tt.wantErr { + if err == nil { + t.Errorf("safePath(%q, %q) = %q, want error", base, tt.input, got) + } + return + } + if err != nil { + t.Errorf("safePath(%q, %q) error: %v", base, tt.input, err) + return + } + if got != tt.want { + t.Errorf("safePath(%q, %q) = %q, want %q", base, tt.input, got, tt.want) + } + }) + } +} diff --git a/internal/web/server.go b/internal/web/server.go index 7f5ae2d..33c055e 100644 --- a/internal/web/server.go +++ b/internal/web/server.go @@ -5,14 +5,12 @@ import ( "fmt" "io/fs" "log" - "net" "net/http" "os" "os/exec" "path/filepath" "strings" "syscall" - "time" "github.com/teslausb-go/teslausb/internal/ble" "github.com/teslausb-go/teslausb/internal/config" @@ -113,12 +111,12 @@ func (s *Server) handleListFiles(w http.ResponseWriter, r *http.Request) { if reqPath == "" { reqPath = "TeslaCam" } - reqPath = filepath.Clean(reqPath) - if strings.Contains(reqPath, "..") { + fullPath, err := safePath(disk.MountPoint, reqPath) + if err != nil { http.Error(w, "invalid path", 400) return } - fullPath := filepath.Join(disk.MountPoint, reqPath) + reqPath = filepath.Clean(reqPath) // normalize for JSON response paths entries, err := os.ReadDir(fullPath) if err != nil { jsonResponse(w, []any{}) @@ -149,12 +147,11 @@ func (s *Server) handleListFiles(w http.ResponseWriter, r *http.Request) { } func (s *Server) handleDownloadFile(w http.ResponseWriter, r *http.Request) { - reqPath := filepath.Clean(r.URL.Query().Get("path")) - if strings.Contains(reqPath, "..") { + fullPath, err := safePath(disk.MountPoint, r.URL.Query().Get("path")) + if err != nil { http.Error(w, "invalid path", 400) return } - fullPath := filepath.Join(disk.MountPoint, reqPath) http.ServeFile(w, r, fullPath) } @@ -163,12 +160,11 @@ func (s *Server) handleDeleteFile(w http.ResponseWriter, r *http.Request) { Path string `json:"path"` } json.NewDecoder(r.Body).Decode(&req) - req.Path = filepath.Clean(req.Path) - if strings.Contains(req.Path, "..") { + fullPath, err := safePath(disk.MountPoint, req.Path) + if err != nil { http.Error(w, "invalid path", 400) return } - fullPath := filepath.Join(disk.MountPoint, req.Path) if err := os.RemoveAll(fullPath); err != nil { http.Error(w, err.Error(), 500) return @@ -208,30 +204,18 @@ func (s *Server) handleTestNFS(w http.ResponseWriter, r *http.Request) { return } - // Test TCP connectivity to NFS port - conn, err := net.DialTimeout("tcp", req.Server+":2049", 5*time.Second) - if err != nil { - jsonResponse(w, map[string]any{"ok": false, "error": fmt.Sprintf("Cannot reach %s:2049 — %v", req.Server, err)}) - return - } - conn.Close() - - // Try a temporary mount - testDir := "/tmp/nfs-test" - os.MkdirAll(testDir, 0755) - defer func() { - exec.Command("umount", "-f", "-l", testDir).Run() - os.Remove(testDir) - }() - source := fmt.Sprintf("%s:%s", req.Server, req.Share) - out, err := exec.Command("mount", "-t", "nfs", source, testDir, "-o", "ro,nolock,proto=tcp,vers=3,timeo=10,retrans=1").CombinedOutput() - if err != nil { - jsonResponse(w, map[string]any{"ok": false, "error": fmt.Sprintf("Mount failed: %s", strings.TrimSpace(string(out)))}) - return - } - - jsonResponse(w, map[string]any{"ok": true, "message": fmt.Sprintf("Successfully mounted %s", source)}) + s.runMountTest(w, mountTest{ + host: req.Server, + port: "2049", + source: source, + testDir: "/tmp/nfs-test", + mount: func(dir string) ([]byte, error) { + cmd := exec.Command("mount", "-t", "nfs", source, dir, + "-o", "ro,nolock,proto=tcp,vers=3,timeo=10,retrans=1") + return cmd.CombinedOutput() + }, + }) } func (s *Server) handleTestCIFS(w http.ResponseWriter, r *http.Request) { @@ -247,34 +231,22 @@ func (s *Server) handleTestCIFS(w http.ResponseWriter, r *http.Request) { return } - // Test TCP connectivity to SMB port - conn, err := net.DialTimeout("tcp", req.Server+":445", 5*time.Second) - if err != nil { - jsonResponse(w, map[string]any{"ok": false, "error": fmt.Sprintf("Cannot reach %s:445 — %v", req.Server, err)}) - return - } - conn.Close() - - // Try a temporary mount - testDir := "/tmp/cifs-test" - os.MkdirAll(testDir, 0755) - defer func() { - exec.Command("umount", "-f", "-l", testDir).Run() - os.Remove(testDir) - }() - source := fmt.Sprintf("//%s/%s", req.Server, req.Share) credFile := "/tmp/.cifs-test-credentials" - os.WriteFile(credFile, []byte(fmt.Sprintf("username=%s\npassword=%s\n", req.Username, req.Password)), 0600) - defer os.Remove(credFile) - opts := fmt.Sprintf("credentials=%s,vers=3.0", credFile) - out, err := exec.Command("mount", "-t", "cifs", source, testDir, "-o", opts).CombinedOutput() - if err != nil { - jsonResponse(w, map[string]any{"ok": false, "error": fmt.Sprintf("Mount failed: %s", strings.TrimSpace(string(out)))}) - return - } - jsonResponse(w, map[string]any{"ok": true, "message": fmt.Sprintf("Successfully mounted %s", source)}) + s.runMountTest(w, mountTest{ + host: req.Server, + port: "445", + source: source, + testDir: "/tmp/cifs-test", + mount: func(dir string) ([]byte, error) { + os.WriteFile(credFile, []byte(fmt.Sprintf("username=%s\npassword=%s\n", req.Username, req.Password)), 0600) + opts := fmt.Sprintf("credentials=%s,vers=3.0", credFile) + cmd := exec.Command("mount", "-t", "cifs", source, dir, "-o", opts) + return cmd.CombinedOutput() + }, + cleanup: func() { os.Remove(credFile) }, + }) } func (s *Server) handleTriggerArchive(w http.ResponseWriter, r *http.Request) { diff --git a/internal/web/server_test.go b/internal/web/server_test.go index c84baa5..c7e16a1 100644 --- a/internal/web/server_test.go +++ b/internal/web/server_test.go @@ -1,16 +1,65 @@ package web import ( + "context" "encoding/json" "net/http" "net/http/httptest" "testing" "github.com/teslausb-go/teslausb/internal/state" + "github.com/teslausb-go/teslausb/internal/webhook" ) +func webTestDeps() state.Deps { + return state.Deps{ + Disk: &tdisk{}, + Gadget: &tgadget{}, + Archive: &tarchive{}, + System: &tsystem{}, + KeepAwake: &tkeep{}, + Notify: &tnotify{}, + } +} + +type tdisk struct{} + +func (d *tdisk) Exists() bool { return true } +func (d *tdisk) Create() error { return nil } +func (d *tdisk) Mount() error { return nil } +func (d *tdisk) Unmount() error { return nil } +func (d *tdisk) CleanArtifacts() {} +func (d *tdisk) BackingFilePath() string { return "/tmp/test.bin" } + +type tgadget struct{} + +func (g *tgadget) Enable(string) error { return nil } +func (g *tgadget) Disable() error { return nil } +func (g *tgadget) WaitForIdle() error { return nil } + +type tarchive struct{} + +func (a *tarchive) IsReachable() bool { return false } +func (a *tarchive) MountArchive() error { return nil } +func (a *tarchive) UnmountArchive() {} +func (a *tarchive) ArchiveClips(context.Context) (int, int64, error) { return 0, 0, nil } +func (a *tarchive) ManageFreeSpace() {} + +type tsystem struct{} + +func (s *tsystem) SetLED(string) {} +func (s *tsystem) SyncTime() {} + +type tkeep struct{} + +func (k *tkeep) Send(context.Context, string) {} + +type tnotify struct{} + +func (n *tnotify) Send(context.Context, webhook.Event) {} + func TestStatusEndpoint(t *testing.T) { - m := state.New() + m := state.New(webTestDeps()) s := NewServer(m, "test-version", "/tmp/test-config.yaml") req := httptest.NewRequest("GET", "/api/status", nil) @@ -28,7 +77,7 @@ func TestStatusEndpoint(t *testing.T) { } func TestGetConfigEndpoint(t *testing.T) { - m := state.New() + m := state.New(webTestDeps()) s := NewServer(m, "test", "/tmp/test.yaml") req := httptest.NewRequest("GET", "/api/config", nil) @@ -41,7 +90,7 @@ func TestGetConfigEndpoint(t *testing.T) { } func TestBLEStatusEndpoint(t *testing.T) { - m := state.New() + m := state.New(webTestDeps()) s := NewServer(m, "test", "/tmp/test.yaml") req := httptest.NewRequest("GET", "/api/ble/status", nil) diff --git a/internal/web/ws.go b/internal/web/ws.go index 5c7f1bc..9d2774e 100644 --- a/internal/web/ws.go +++ b/internal/web/ws.go @@ -2,7 +2,6 @@ package web import ( "encoding/json" - "log" "net/http" "sync" @@ -10,7 +9,7 @@ import ( ) type Hub struct { - mu sync.RWMutex + mu sync.Mutex clients map[*websocket.Conn]bool } @@ -46,12 +45,17 @@ func (h *Hub) Broadcast(data any) { return } - h.mu.RLock() - defer h.mu.RUnlock() + h.mu.Lock() + defer h.mu.Unlock() + var dead []*websocket.Conn for ws := range h.clients { if _, err := ws.Write(msg); err != nil { - log.Printf("ws write error: %v", err) + dead = append(dead, ws) } } + for _, ws := range dead { + delete(h.clients, ws) + ws.Close() + } } diff --git a/internal/web/ws_test.go b/internal/web/ws_test.go new file mode 100644 index 0000000..2819468 --- /dev/null +++ b/internal/web/ws_test.go @@ -0,0 +1,95 @@ +package web + +import ( + "encoding/json" + "net/http/httptest" + "testing" + "time" + + "golang.org/x/net/websocket" +) + +func TestHubNewHub(t *testing.T) { + h := NewHub() + if h.clients == nil { + t.Error("clients map should be initialized") + } +} + +func TestHubBroadcastEmptyHub(t *testing.T) { + h := NewHub() + // Should not panic on empty hub + h.Broadcast(map[string]string{"test": "data"}) + + h.mu.Lock() + count := len(h.clients) + h.mu.Unlock() + + if count != 0 { + t.Errorf("expected 0 clients, got %d", count) + } +} + +func TestHubBroadcastRemovesDeadClients(t *testing.T) { + h := NewHub() + + // Create a real WebSocket server + client using httptest + srv := httptest.NewServer(websocket.Handler(func(ws *websocket.Conn) { + h.mu.Lock() + h.clients[ws] = true + h.mu.Unlock() + + // Read until error (keeps handler alive) + buf := make([]byte, 1024) + for { + if _, err := ws.Read(buf); err != nil { + return + } + } + })) + defer srv.Close() + + // Connect a client, then close it to simulate a dead connection + wsURL := "ws" + srv.URL[4:] // http -> ws + ws, err := websocket.Dial(wsURL, "", srv.URL) + if err != nil { + t.Fatalf("dial: %v", err) + } + + // Wait for the handler to register the client + for i := 0; i < 100; i++ { + h.mu.Lock() + n := len(h.clients) + h.mu.Unlock() + if n > 0 { + break + } + time.Sleep(time.Millisecond) + } + + // Close the client connection — next write will fail + ws.Close() + + // Allow TCP close to propagate to the server side + time.Sleep(50 * time.Millisecond) + + // Broadcast should detect the dead connection and remove it + h.Broadcast(map[string]string{"test": "data"}) + + h.mu.Lock() + remaining := len(h.clients) + h.mu.Unlock() + + if remaining != 0 { + t.Errorf("expected dead client to be removed, got %d remaining", remaining) + } +} + +// Ensure json.Marshal works for Broadcast payloads +func TestBroadcastPayloadMarshal(t *testing.T) { + data := map[string]any{"type": "state", "state": "idle"} + _, err := json.Marshal(data) + if err != nil { + t.Errorf("expected payload to marshal: %v", err) + } +} diff --git a/internal/wire/wire.go b/internal/wire/wire.go new file mode 100644 index 0000000..72483d9 --- /dev/null +++ b/internal/wire/wire.go @@ -0,0 +1,90 @@ +package wire + +import ( + "context" + + "github.com/teslausb-go/teslausb/internal/archive" + "github.com/teslausb-go/teslausb/internal/ble" + "github.com/teslausb-go/teslausb/internal/config" + "github.com/teslausb-go/teslausb/internal/disk" + "github.com/teslausb-go/teslausb/internal/gadget" + "github.com/teslausb-go/teslausb/internal/notify" + "github.com/teslausb-go/teslausb/internal/state" + "github.com/teslausb-go/teslausb/internal/system" + "github.com/teslausb-go/teslausb/internal/webhook" +) + +// DiskAdapter wraps the disk package functions. +type DiskAdapter struct{} + +func (d *DiskAdapter) Exists() bool { return disk.Exists() } +func (d *DiskAdapter) Create() error { return disk.Create() } +func (d *DiskAdapter) Mount() error { return disk.Mount() } +func (d *DiskAdapter) Unmount() error { return disk.Unmount() } +func (d *DiskAdapter) CleanArtifacts() { disk.CleanArtifacts() } +func (d *DiskAdapter) BackingFilePath() string { return disk.BackingFile } + +// GadgetAdapter wraps the gadget package functions. +type GadgetAdapter struct{} + +func (g *GadgetAdapter) Enable(backingFile string) error { return gadget.Enable(backingFile) } +func (g *GadgetAdapter) Disable() error { return gadget.Disable() } +func (g *GadgetAdapter) WaitForIdle() error { return gadget.WaitForIdle() } + +// ArchiveAdapter wraps the archive package functions. +type ArchiveAdapter struct{} + +func (a *ArchiveAdapter) IsReachable() bool { return archive.IsReachable() } +func (a *ArchiveAdapter) MountArchive() error { return archive.MountArchive() } +func (a *ArchiveAdapter) UnmountArchive() { archive.UnmountArchive() } +func (a *ArchiveAdapter) ArchiveClips(ctx context.Context) (int, int64, error) { return archive.ArchiveClips(ctx) } +func (a *ArchiveAdapter) ManageFreeSpace() { archive.ManageFreeSpace() } + +// SystemAdapter wraps the system package functions. +type SystemAdapter struct{} + +func (s *SystemAdapter) SetLED(mode string) { system.SetLED(mode) } +func (s *SystemAdapter) SyncTime() { system.SyncTime() } + +// KeepAwakeAdapter dispatches keep-awake commands based on config. +type KeepAwakeAdapter struct{} + +func (k *KeepAwakeAdapter) Send(ctx context.Context, command string) { + cfg := config.Get() + if cfg == nil { + return + } + switch cfg.KeepAwake.Method { + case "ble": + if cfg.KeepAwake.VIN != "" { + if command == "stop" { + ble.SentryOff(cfg.KeepAwake.VIN) + } else { + ble.KeepAwake(cfg.KeepAwake.VIN) + } + } + case "webhook": + if cfg.KeepAwake.WebhookURL != "" { + webhook.SendRaw(ctx, cfg.KeepAwake.WebhookURL, map[string]string{"awake_command": command}) + } + } +} + +// NotifyAdapter wraps the notify package. +type NotifyAdapter struct{} + +func (n *NotifyAdapter) Send(ctx context.Context, event webhook.Event) { + notify.Send(ctx, event) +} + +// NewDeps creates a Deps struct wired to the real implementations. +func NewDeps() state.Deps { + return state.Deps{ + Disk: &DiskAdapter{}, + Gadget: &GadgetAdapter{}, + Archive: &ArchiveAdapter{}, + System: &SystemAdapter{}, + KeepAwake: &KeepAwakeAdapter{}, + Notify: &NotifyAdapter{}, + } +}