diff --git a/drivers/hyperv/hyperv.go b/drivers/hyperv/hyperv.go index 329f55d..37c4415 100644 --- a/drivers/hyperv/hyperv.go +++ b/drivers/hyperv/hyperv.go @@ -2,9 +2,13 @@ package hyperv import ( "encoding/json" + "errors" "fmt" + "io/ioutil" "net" "os" + "path/filepath" + "regexp" "strings" "time" @@ -19,6 +23,7 @@ import ( type Driver struct { *drivers.BaseDriver Boot2DockerURL string + WindowsVHDUrl string VSwitch string DiskSize int MemSize int @@ -26,6 +31,7 @@ type Driver struct { MacAddr string VLanID int DisableDynamicMemory bool + OS string } const ( @@ -35,6 +41,7 @@ const ( defaultVLanID = 0 defaultDisableDynamicMemory = false defaultSwitchID = "c08cb7b8-9b3c-408e-8e30-5e16a3aeb444" + defaultServerImageFilename = "hybrid-minikube-windows-server.vhdx" ) // NewDriver creates a new Hyper-v driver with default settings. @@ -43,6 +50,7 @@ func NewDriver(hostName, storePath string) *Driver { DiskSize: defaultDiskSize, MemSize: defaultMemory, CPU: defaultCPU, + WindowsVHDUrl: mcnutils.ConfigGuest.GetVHDUrl(), DisableDynamicMemory: defaultDisableDynamicMemory, BaseDriver: &drivers.BaseDriver{ MachineName: hostName, @@ -142,7 +150,7 @@ func (d *Driver) GetURL() (string, error) { func (d *Driver) GetState() (state.State, error) { stdout, err := cmdOut("(", "Hyper-V\\Get-VM", d.MachineName, ").state") if err != nil { - return state.None, fmt.Errorf("Failed to find the VM status") + return state.None, fmt.Errorf("failed to find the VM status") } resp := parseLines(stdout) @@ -186,20 +194,35 @@ func (d *Driver) PreCreateCheck() error { return err } - // Downloading boot2docker to cache should be done here to make sure + // Downloading boot2docker/windows-server to cache should be done here to make sure // that a download failure will not leave a machine half created. b2dutils := mcnutils.NewB2dUtils(d.StorePath) - err = b2dutils.UpdateISOCache(d.Boot2DockerURL) + + if mcnutils.ConfigGuest.GetGuestOS() != "windows" { + err = b2dutils.UpdateISOCache(d.Boot2DockerURL) + } else { + err = b2dutils.UpdateVHDCache(d.WindowsVHDUrl) + } + return err } func (d *Driver) Create() error { b2dutils := mcnutils.NewB2dUtils(d.StorePath) - if err := b2dutils.CopyIsoToMachineDir(d.Boot2DockerURL, d.MachineName); err != nil { - return err + + if mcnutils.ConfigGuest.GetGuestOS() == "windows" { + d.SSHUser = "Administrator" + if err := b2dutils.CopyWindowsVHDToMachineDir(d.WindowsVHDUrl, d.MachineName); err != nil { + return err + } + } else { + if err := b2dutils.CopyIsoToMachineDir(d.Boot2DockerURL, d.MachineName); err != nil { + return err + } } log.Infof("Creating SSH key...") + if err := ssh.GenerateSSHKey(d.GetSSHKeyPath()); err != nil { return err } @@ -214,18 +237,36 @@ func (d *Driver) Create() error { } log.Infof("Using switch %q", d.VSwitch) - diskImage, err := d.generateDiskImage() + if mcnutils.ConfigGuest.GetGuestOS() == "windows" { + log.Infof("Adding SSH key to the VHDX...") + if err := writeSSHKeyToVHDX(d.ResolveStorePath(defaultServerImageFilename), d.publicSSHKeyPath()); err != nil { + log.Errorf("Error creating disk image: %s", err) + return err + } + } + + var diskImage string + var err error + if mcnutils.ConfigGuest.GetGuestOS() != "windows" { + diskImage, err = d.generateDiskImage() + } if err != nil { return err } + vmGeneration := "1" + if mcnutils.ConfigGuest.GetGuestOS() == "windows" { + vmGeneration = "2" + } + if err := cmd("Hyper-V\\New-VM", d.MachineName, - "-Path", fmt.Sprintf("'%s'", d.ResolveStorePath(".")), "-SwitchName", quote(d.VSwitch), + "-Generation", quote(vmGeneration), "-MemoryStartupBytes", toMb(d.MemSize)); err != nil { return err } + if d.DisableDynamicMemory { if err := cmd("Hyper-V\\Set-VMMemory", "-VMName", d.MachineName, @@ -237,7 +278,8 @@ func (d *Driver) Create() error { if d.CPU > 1 { if err := cmd("Hyper-V\\Set-VMProcessor", d.MachineName, - "-Count", fmt.Sprintf("%d", d.CPU)); err != nil { + "-Count", fmt.Sprintf("%d", d.CPU), + "-ExposeVirtualizationExtensions", "$true"); err != nil { return err } } @@ -259,19 +301,29 @@ func (d *Driver) Create() error { } } - if err := cmd("Hyper-V\\Set-VMDvdDrive", - "-VMName", d.MachineName, - "-Path", quote(d.ResolveStorePath("boot2docker.iso"))); err != nil { - return err + if mcnutils.ConfigGuest.GetGuestOS() != "windows" { + log.Infof("Attaching ISO and disk...") + if err := cmd("Hyper-V\\Set-VMDvdDrive", + "-VMName", d.MachineName, + "-Path", quote(d.ResolveStorePath("boot2docker.iso"))); err != nil { + return err + } } - if err := cmd("Hyper-V\\Add-VMHardDiskDrive", - "-VMName", d.MachineName, - "-Path", quote(diskImage)); err != nil { - return err + if mcnutils.ConfigGuest.GetGuestOS() == "windows" { + if err := cmd("Hyper-V\\Add-VMHardDiskDrive", + "-VMName", d.MachineName, + "-Path", quote(d.ResolveStorePath("hybrid-minikube-windows-server.vhdx")), + "-ControllerType", "SCSI"); err != nil { + return err + } + } else { + if err := cmd("Hyper-V\\Add-VMHardDiskDrive", + "-VMName", d.MachineName, + "-Path", quote(diskImage)); err != nil { + return err + } } - - log.Infof("Starting VM...") return d.Start() } @@ -507,3 +559,57 @@ func (d *Driver) generateDiskImage() (string, error) { return diskImage, nil } +func writeSSHKeyToVHDX(vhdxPath, publicSSHKeyPath string) (retErr error) { + output, err := cmdOut( + "powershell", "-Command", + "(Get-DiskImage -ImagePath", quote(vhdxPath), "| Mount-DiskImage -PassThru) | Out-Null;", + "$diskNumber = (Get-DiskImage -ImagePath", quote(vhdxPath), "| Get-Disk).Number;", + "Set-Disk -Number $diskNumber -IsReadOnly $false;", + "(Get-Disk -Number $diskNumber | Get-Partition | Get-Volume).DriveLetter", + ) + if err != nil { + return fmt.Errorf("failed to mount VHDX and retrieve mount directory: %w", err) + } + + regex := regexp.MustCompile(`\s+|\r|\n`) + driveLetter := regex.ReplaceAllString(output, "") + + if driveLetter == "" { + log.Debugf("No drive letter assigned to VHDX") + return errors.New("no drive letter assigned to VHDX") + } + + mountDir := strings.TrimSpace(driveLetter) + ":" + string(os.PathSeparator) + + defer func() { + if unmountErr := cmd("Dismount-DiskImage", "-ImagePath", quote(vhdxPath)); unmountErr != nil { + retErr = errors.Join(retErr, fmt.Errorf("failed to unmount VHDX: %w", unmountErr)) + } + }() + + sshDir := filepath.Join(mountDir, "ProgramData", "ssh") + adminAuthKeys := filepath.Join(sshDir, "administrators_authorized_keys") + + pubKey, err := os.ReadFile(publicSSHKeyPath) + if err != nil { + return fmt.Errorf("failed to read public SSH key from %s: %w", publicSSHKeyPath, err) + } + + if _, err := os.Stat(mountDir); os.IsNotExist(err) { + return fmt.Errorf("mount point %s does not exist", mountDir) + } + + if err := os.MkdirAll(sshDir, 0755); err != nil { + return fmt.Errorf("failed to create SSH directory: %w", err) + } + + if err := ioutil.WriteFile(adminAuthKeys, pubKey, 0644); err != nil { + return fmt.Errorf("failed to write public key: %w", err) + } + + if err := cmd("icacls.exe", quote(adminAuthKeys), "/inheritance:r", "/grant", "Administrators:F", "/grant", "SYSTEM:F"); err != nil { + return fmt.Errorf("failed to set permissions on %s: %w", adminAuthKeys, err) + } + + return nil +} diff --git a/drivers/hyperv/powershell.go b/drivers/hyperv/powershell.go index 3447b07..be15cdd 100644 --- a/drivers/hyperv/powershell.go +++ b/drivers/hyperv/powershell.go @@ -61,7 +61,7 @@ func hypervAvailable() error { } resp := parseLines(stdout) - if resp == nil || len(resp) == 0 || resp[0] != "Hyper-V" { + if len(resp) == 0 || resp[0] != "Hyper-V" { return ErrNotInstalled } @@ -106,7 +106,7 @@ func isWindowsAdministrator() (bool, error) { } func quote(text string) string { - return fmt.Sprintf("'%s'", text) + return fmt.Sprintf(`"%s"`, text) } func toMb(value int) string { diff --git a/libmachine/examples/main.go b/libmachine/examples/main.go index 7decc83..f7e55de 100644 --- a/libmachine/examples/main.go +++ b/libmachine/examples/main.go @@ -37,7 +37,7 @@ func create() { return } - h, err := client.NewHost("virtualbox", data) + h, err := client.NewHost("virtualbox", "linux", data) if err != nil { log.Error(err) return @@ -82,7 +82,7 @@ func streaming() { return } - h, err := client.NewHost("virtualbox", data) + h, err := client.NewHost("virtualbox", "linux", data) if err != nil { log.Error(err) return diff --git a/libmachine/host/host.go b/libmachine/host/host.go index d36e187..ac28abc 100644 --- a/libmachine/host/host.go +++ b/libmachine/host/host.go @@ -44,6 +44,7 @@ type Host struct { HostOptions *Options Name string RawDriver []byte `json:"-"` + Guest Guest } type Options struct { @@ -61,6 +62,12 @@ type Metadata struct { HostOptions Options } +type Guest struct { + Name string + Version string + URL string +} + func ValidateHostName(name string) bool { return validHostNamePattern.MatchString(name) } diff --git a/libmachine/libmachine.go b/libmachine/libmachine.go index 5d0c8fb..cc138f7 100644 --- a/libmachine/libmachine.go +++ b/libmachine/libmachine.go @@ -12,7 +12,7 @@ import ( "github.com/docker/machine/libmachine/check" "github.com/docker/machine/libmachine/drivers" "github.com/docker/machine/libmachine/drivers/plugin/localbinary" - "github.com/docker/machine/libmachine/drivers/rpc" + rpcdriver "github.com/docker/machine/libmachine/drivers/rpc" "github.com/docker/machine/libmachine/engine" "github.com/docker/machine/libmachine/host" "github.com/docker/machine/libmachine/log" @@ -28,7 +28,8 @@ import ( type API interface { io.Closer - NewHost(driverName string, rawDriver []byte) (*host.Host, error) + NewHost(driverName string, guest host.Guest, rawDriver []byte) (*host.Host, error) + DefineGuest(h *host.Host) Create(h *host.Host) error persist.Store GetMachinesDir() string @@ -53,7 +54,7 @@ func NewClient(storePath, certsDir string) *Client { } } -func (api *Client) NewHost(driverName string, rawDriver []byte) (*host.Host, error) { +func (api *Client) NewHost(driverName string, guest host.Guest, rawDriver []byte) (*host.Host, error) { driver, err := api.clientDriverFactory.NewRPCClientDriver(driverName, rawDriver) if err != nil { return nil, err @@ -64,6 +65,7 @@ func (api *Client) NewHost(driverName string, rawDriver []byte) (*host.Host, err Name: driver.GetMachineName(), Driver: driver, DriverName: driver.DriverName(), + Guest: guest, HostOptions: &host.Options{ AuthOptions: &auth.Options{ CertDir: api.certsDir, @@ -113,11 +115,15 @@ func (api *Client) Load(name string) (*host.Host, error) { return h, nil } +func (api *Client) DefineGuest(h *host.Host) { + mcnutils.SetGuestUtil(h.Guest.Name, h.Guest.URL) +} + // Create is the wrapper method which covers all of the boilerplate around // actually creating, provisioning, and persisting an instance in the store. func (api *Client) Create(h *host.Host) error { if err := cert.BootstrapCertificates(h.AuthOptions()); err != nil { - return fmt.Errorf("Error generating certificates: %s", err) + return fmt.Errorf("error generating certificates: %s", err) } log.Info("Running pre-create checks...") @@ -129,13 +135,13 @@ func (api *Client) Create(h *host.Host) error { } if err := api.Save(h); err != nil { - return fmt.Errorf("Error saving host to store before attempting creation: %s", err) + return fmt.Errorf("error saving host to store before attempting creation: %s", err) } log.Info("Creating machine...") if err := api.performCreate(h); err != nil { - return fmt.Errorf("Error creating machine: %s", err) + return fmt.Errorf("error creating machine: %s", err) } log.Debug("Reticulating splines...") @@ -145,11 +151,11 @@ func (api *Client) Create(h *host.Host) error { func (api *Client) performCreate(h *host.Host) error { if err := h.Driver.Create(); err != nil { - return fmt.Errorf("Error in driver during machine creation: %s", err) + return fmt.Errorf("error in driver during machine creation: %s", err) } if err := api.Save(h); err != nil { - return fmt.Errorf("Error saving host to store after attempting creation: %s", err) + return fmt.Errorf("error saving host to store after attempting creation: %s", err) } // TODO: Not really a fan of just checking "none" or "ci-test" here. @@ -159,24 +165,24 @@ func (api *Client) performCreate(h *host.Host) error { log.Info("Waiting for machine to be running, this may take a few minutes...") if err := mcnutils.WaitFor(drivers.MachineInState(h.Driver, state.Running)); err != nil { - return fmt.Errorf("Error waiting for machine to be running: %s", err) + return fmt.Errorf("error waiting for machine to be running: %s", err) } log.Info("Detecting operating system of created instance...") provisioner, err := provision.DetectProvisioner(h.Driver) if err != nil { - return fmt.Errorf("Error detecting OS: %s", err) + return fmt.Errorf("error detecting OS: %s", err) } log.Infof("Provisioning with %s...", provisioner.String()) if err := provisioner.Provision(*h.HostOptions.SwarmOptions, *h.HostOptions.AuthOptions, *h.HostOptions.EngineOptions); err != nil { - return fmt.Errorf("Error running provisioning: %s", err) + return fmt.Errorf("error running provisioning: %s", err) } // We should check the connection to docker here log.Info("Checking connection to Docker...") if _, _, err = check.DefaultConnChecker.Check(h, false); err != nil { - return fmt.Errorf("Error checking the host: %s", err) + return fmt.Errorf("error checking the host: %s", err) } log.Info("Docker is up and running!") diff --git a/libmachine/libmachinetest/fake_api.go b/libmachine/libmachinetest/fake_api.go index 4566257..3555f5c 100644 --- a/libmachine/libmachinetest/fake_api.go +++ b/libmachine/libmachinetest/fake_api.go @@ -20,7 +20,7 @@ func (api *FakeAPI) Close() error { return nil } -func (api *FakeAPI) NewHost(driverName string, rawDriver []byte) (*host.Host, error) { +func (api *FakeAPI) NewHost(driverName string, guestOS string, rawDriver []byte) (*host.Host, error) { return nil, nil } diff --git a/libmachine/mcnutils/b2d.go b/libmachine/mcnutils/b2d.go index dacc1fc..409bc6b 100644 --- a/libmachine/mcnutils/b2d.go +++ b/libmachine/mcnutils/b2d.go @@ -22,11 +22,12 @@ import ( ) const ( - defaultURL = "https://api.github.com/repos/boot2docker/boot2docker/releases" - defaultISOFilename = "boot2docker.iso" - defaultVolumeIDOffset = int64(0x8028) - versionPrefix = "-v" - defaultVolumeIDLength = 32 + defaultURL = "https://api.github.com/repos/boot2docker/boot2docker/releases" + defaultISOFilename = "boot2docker.iso" + defaultServerImageFilename = "hybrid-minikube-windows-server.vhdx" + defaultVolumeIDOffset = int64(0x8028) + versionPrefix = "-v" + defaultVolumeIDLength = 32 ) var ( @@ -194,6 +195,9 @@ Consider specifying another storage driver (e.g. 'overlay') using '--engine-stor func (*b2dReleaseGetter) download(dir, file, isoURL string) error { u, err := url.Parse(isoURL) + if err != nil { + return err + } var src io.ReadCloser if u.Scheme == "file" || u.Scheme == "" { @@ -261,6 +265,10 @@ type iso interface { path() string // exists reports whether the ISO exists. exists() bool + // pathVHD returns the path of the VHD. + pathVHD() string + // hasVHD returns whether the server VHD exists. + hasVHD() bool // version returns version information of the ISO. version() (string, error) } @@ -269,6 +277,8 @@ type iso interface { type b2dISO struct { // path of Boot2Docker ISO commonIsoPath string + // path of Windows Server VHD + commonVHDPath string // offset and length of ISO volume ID // cf. http://serverfault.com/questions/361474/is-there-a-way-to-change-a-iso-files-volume-id-from-the-command-line @@ -292,6 +302,22 @@ func (b *b2dISO) exists() bool { return !os.IsNotExist(err) } +func (b *b2dISO) pathVHD() string { + if b == nil { + return "" + } + return b.commonVHDPath +} + +func (b *b2dISO) hasVHD() bool { + if b == nil { + return false + } + + _, err := os.Stat(b.commonVHDPath) + return !os.IsNotExist(err) +} + // version scans the volume ID in b and returns its version tag. func (b *b2dISO) version() (string, error) { if b == nil { @@ -314,7 +340,7 @@ func (b *b2dISO) version() (string, error) { versionIndex := strings.Index(trimmedVersion, versionPrefix) if versionIndex == -1 { - return "", fmt.Errorf("Did not find prefix %q in version string", versionPrefix) + return "", fmt.Errorf("did not find prefix %q in version string", versionPrefix) } // Original magic file string looks similar to this: "Boot2Docker-v0.1.0 " @@ -348,6 +374,7 @@ func NewB2dUtils(storePath string) *B2dUtils { releaseGetter: &b2dReleaseGetter{isoFilename: defaultISOFilename}, iso: &b2dISO{ commonIsoPath: filepath.Join(imgCachePath, defaultISOFilename), + commonVHDPath: filepath.Join(imgCachePath, defaultServerImageFilename), volumeIDOffset: defaultVolumeIDOffset, volumeIDLength: defaultVolumeIDLength, }, @@ -356,12 +383,22 @@ func NewB2dUtils(storePath string) *B2dUtils { } } +func (b *B2dUtils) GetImgCachePath() string { + return b.imgCachePath +} + // DownloadISO downloads boot2docker ISO image for the given tag and save it at dest. func (b *B2dUtils) DownloadISO(dir, file, isoURL string) error { log.Infof("Downloading %s from %s...", b.path(), isoURL) return b.download(dir, file, isoURL) } +// DownloadVHD downloads the Windows Server VHD image and saves it at dest. +func (b *B2dUtils) DownloadVHD(dir, file, vhdURL string) error { + log.Infof("Downloading %s from %s...", b.pathVHD(), vhdURL) + return b.download(dir, file, vhdURL) +} + type ReaderWithProgress struct { io.ReadCloser out io.Writer @@ -408,6 +445,31 @@ func (b *B2dUtils) DownloadISOFromURL(latestReleaseURL string) error { return b.DownloadISO(b.imgCachePath, b.filename(), latestReleaseURL) } +func (b *B2dUtils) UpdateVHDCache(defaultVHDUrl string) error { + // recreate the cache dir if it has been manually deleted + // this will already be taken care of by the UpdateISOCache method for linux ISO + + exists := b.hasVHD() + + if !exists { + log.Info("No default Windows Server VHD found locally, downloading the latest release...") + + filePath := filepath.Join(b.imgCachePath, defaultServerImageFilename) + + fmt.Printf("\n") + fmt.Printf(" * Downloading and caching Windows Server VHD image...\n") + fmt.Printf(" * This may take a while...\n") + err := DownloadVHDX(defaultVHDUrl, filePath, 16, 1) // Download using 16 parts + + if err != nil { + return fmt.Errorf("Error: %v", err) + } + log.Info("Windows Server VHD downloaded successfully") + } + + return nil +} + func (b *B2dUtils) UpdateISOCache(isoURL string) error { // recreate the cache dir if it has been manually deleted if _, err := os.Stat(b.imgCachePath); os.IsNotExist(err) { @@ -466,6 +528,24 @@ func (b *B2dUtils) CopyIsoToMachineDir(isoURL, machineName string) error { return b.DownloadISO(machineDir, b.filename(), downloadURL) } +func (b *B2dUtils) CopyWindowsVHDToMachineDir(VHDUrl, machineName string) error { + + if err := b.UpdateVHDCache(VHDUrl); err != nil { + return err + } + + machineDir := filepath.Join(b.storePath, "machines", machineName) + + windowsMachineVHDPath := filepath.Join(machineDir, defaultServerImageFilename) + + // cached location of the windows iso + windowsVHDPath := filepath.Join(b.imgCachePath, defaultServerImageFilename) + + log.Infof("Copying %s to %s...", windowsVHDPath, windowsMachineVHDPath) + return CopyFile(windowsVHDPath, windowsMachineVHDPath) + +} + // isLatest checks the latest release tag and // reports whether the local ISO cache is the latest version. // diff --git a/libmachine/mcnutils/b2d_test.go b/libmachine/mcnutils/b2d_test.go index cf1cfad..c243229 100644 --- a/libmachine/mcnutils/b2d_test.go +++ b/libmachine/mcnutils/b2d_test.go @@ -207,10 +207,12 @@ func (m *mockReleaseGetter) download(dir, file, isoURL string) error { } type mockISO struct { - isopath string - exist bool - ver string - verCh <-chan string + isopath string + exist bool + ver string + vhdpath string + vhdexist bool + verCh <-chan string } func (m *mockISO) path() string { @@ -221,6 +223,14 @@ func (m *mockISO) exists() bool { return m.exist } +func (m *mockISO) pathVHD() string { + return m.vhdpath +} + +func (m *mockISO) hasVHD() bool { + return m.vhdexist +} + func (m *mockISO) version() (string, error) { select { // receive version of a downloaded iso diff --git a/libmachine/mcnutils/download_vhd.go b/libmachine/mcnutils/download_vhd.go new file mode 100644 index 0000000..d7b6225 --- /dev/null +++ b/libmachine/mcnutils/download_vhd.go @@ -0,0 +1,238 @@ +package mcnutils + +import ( + "fmt" + "io" + "net/http" + "net/url" + "os" + "path/filepath" + "strings" + "sync" + "time" + + "github.com/docker/machine/libmachine/log" +) + +type ProgressWriter struct { + Total int64 + Downloaded int64 + mu sync.Mutex + TargetName string +} + +func NewProgressWriter(total int64, targetName string) *ProgressWriter { + return &ProgressWriter{Total: total, TargetName: targetName} +} + +func (pw *ProgressWriter) Write(p []byte) (int, error) { + n := len(p) + pw.mu.Lock() + pw.Downloaded += int64(n) + pw.printProgress() + pw.mu.Unlock() + return n, nil +} + +func (pw *ProgressWriter) printProgress() { + // Overwrite the same line with \r + fmt.Printf("\r > %s: %d / %d bytes complete", pw.TargetName, pw.Downloaded, pw.Total) +} + +// copyLocalFile copies from a local source path to destination, reporting progress. +func copyLocalFile(srcPath, dstPath string) error { + srcFile, err := os.Open(srcPath) + if err != nil { + return fmt.Errorf("failed to open local source file %q: %w", srcPath, err) + } + defer srcFile.Close() + + // Get total size + info, err := srcFile.Stat() + if err != nil { + return fmt.Errorf("failed to stat local source file %q: %w", srcPath, err) + } + totalSize := info.Size() + + outFile, err := os.Create(dstPath) + if err != nil { + return fmt.Errorf("failed to create destination file %q: %w", dstPath, err) + } + defer outFile.Close() + + pw := NewProgressWriter(totalSize, filepath.Base(dstPath)) + // Use TeeReader: read from srcFile, write to pw (for progress), and to outFile + _, err = io.Copy(outFile, io.TeeReader(srcFile, pw)) + if err != nil { + return fmt.Errorf("error copying local file to %q: %w", dstPath, err) + } + + // Final newline after progress + fmt.Printf("\n") + log.Infof("\t> Local copy complete: %s\n", dstPath) + return nil +} + +// DownloadPart downloads a byte-range [start,end] of the URL into a temporary part file. +// On error, it returns the error; progress for the range is reported via pw. +// The caller goroutine must call wg.Done() exactly once. +func DownloadPart(urlStr string, start, end int64, partFileName string, pw *ProgressWriter, retryLimit int) error { + var resp *http.Response + var err error + + // Retry loop for downloading a part + for retries := 0; retries <= retryLimit; retries++ { + req, errReq := http.NewRequest("GET", urlStr, nil) + if errReq != nil { + log.Errorf("Error creating request: %v", errReq) + return errReq + } + req.Header.Set("Range", fmt.Sprintf("bytes=%d-%d", start, end)) + + resp, err = http.DefaultClient.Do(req) + if err != nil { + log.Errorf("Error downloading part (attempt %d): %v", retries+1, err) + // Exponential backoff before retry + time.Sleep(time.Duration(1< retry + log.Errorf("Error: expected status 206 Partial Content, got %d", resp.StatusCode) + resp.Body.Close() + time.Sleep(time.Duration(1< 2 && localPath[1] == ':' { + localPath = localPath[1:] + } + return copyLocalFile(localPath, filePath) + } + } + // If no scheme or non-file scheme, check if it's a path existing on disk: + if fi, err := os.Stat(urlStr); err == nil && !fi.IsDir() { + // Treat as local file + return copyLocalFile(urlStr, filePath) + } + + // Otherwise assume HTTP(S) URL: + // First, HEAD to get total size + resp, err := http.Head(urlStr) + if err != nil { + return fmt.Errorf("failed to get file info from URL %q: %w", urlStr, err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return fmt.Errorf("bad status from HEAD %q: %s", urlStr, resp.Status) + } + totalSize := resp.ContentLength + if totalSize <= 0 { + return fmt.Errorf("unknown content length for URL %q", urlStr) + } + + // For progress display: use base name of destination + pw := NewProgressWriter(totalSize, filepath.Base(filePath)) + + // Partition download into numParts + partSize := totalSize / int64(numParts) + var wg sync.WaitGroup + var muErr sync.Mutex + downloadErrors := make([]error, 0, numParts) + partFiles := make([]string, numParts) + + for i := 0; i < numParts; i++ { + start := int64(i) * partSize + end := start + partSize - 1 + if i == numParts-1 { + end = totalSize - 1 + } + partFileName := fmt.Sprintf("%s.part-%d.tmp", filePath, i) + partFiles[i] = partFileName + + wg.Add(1) + go func(idx int, s, e int64, pfn string) { + defer wg.Done() + errPart := DownloadPart(urlStr, s, e, pfn, pw, retryLimit) + if errPart != nil { + muErr.Lock() + downloadErrors = append(downloadErrors, fmt.Errorf("part %d: %w", idx, errPart)) + muErr.Unlock() + } + }(i, start, end, partFileName) + } + + // Wait for all parts + wg.Wait() + + if len(downloadErrors) > 0 { + // Clean up partial files + for _, pfn := range partFiles { + os.Remove(pfn) + } + return fmt.Errorf("download failed for parts: %v", downloadErrors) + } + + // Merge parts + outFile, err := os.Create(filePath) + if err != nil { + return fmt.Errorf("failed to create destination file %q: %w", filePath, err) + } + defer outFile.Close() + + for _, pfn := range partFiles { + f, errOpen := os.Open(pfn) + if errOpen != nil { + return fmt.Errorf("failed to open part file %q: %w", pfn, errOpen) + } + _, errCopy := io.Copy(outFile, f) + f.Close() + if errCopy != nil { + return fmt.Errorf("failed to merge part file %q: %w", pfn, errCopy) + } + os.Remove(pfn) + } + + // Final newline after progress + fmt.Printf("\n") + log.Infof("\t> Download complete: %s\n", filePath) + return nil +} diff --git a/libmachine/mcnutils/utils.go b/libmachine/mcnutils/utils.go index a3992ed..bd9e028 100644 --- a/libmachine/mcnutils/utils.go +++ b/libmachine/mcnutils/utils.go @@ -9,12 +9,46 @@ import ( "runtime" "strconv" "time" + + "github.com/docker/machine/libmachine/log" ) type MultiError struct { Errs []error } +type GuestUtil struct { + os string + vhdUrl string +} + +// ConfigGuest is the package-level singleton for GuestUtil +var ConfigGuest *GuestUtil + +func SetGuestUtil(guestOS, vhdUrl string) { + ConfigGuest = &GuestUtil{ + os: guestOS, + vhdUrl: vhdUrl, + } + log.Debugf("SetGuestUtil: os=%s, vhdUrl=%s", guestOS, vhdUrl) +} + +func (g *GuestUtil) GetGuestOS() string { + if g == nil { + log.Debugf("GuestUtil is not initialized") + return "unknown" + } + return g.os +} + +func (g *GuestUtil) GetVHDUrl() string { + if g == nil { + log.Debugf("GuestUtil is not initialized") + return "" + } + return g.vhdUrl +} + func (e MultiError) Error() string { aggregate := "" for _, err := range e.Errs { @@ -89,7 +123,7 @@ func WaitForSpecificOrError(f func() (bool, error), maxAttempts int, waitInterva } time.Sleep(waitInterval) } - return fmt.Errorf("Maximum number of retries (%d) exceeded", maxAttempts) + return fmt.Errorf("maximum number of retries (%d) exceeded", maxAttempts) } func WaitForSpecific(f func() bool, maxAttempts int, waitInterval time.Duration) error { diff --git a/libmachine/ssh/client.go b/libmachine/ssh/client.go index 1f552fc..c7cb431 100644 --- a/libmachine/ssh/client.go +++ b/libmachine/ssh/client.go @@ -168,14 +168,14 @@ func (client *NativeClient) dialSuccess() bool { return true } -func (client *NativeClient) session(command string) (*ssh.Client, *ssh.Session, error) { +func (client *NativeClient) session(_ string) (*ssh.Client, *ssh.Session, error) { if err := mcnutils.WaitFor(client.dialSuccess); err != nil { - return nil, nil, fmt.Errorf("Error attempting SSH client dial: %s", err) + return nil, nil, fmt.Errorf("error attempting SSH client dial: %s", err) } conn, err := ssh.Dial("tcp", net.JoinHostPort(client.Hostname, strconv.Itoa(client.Port)), &client.Config) if err != nil { - return nil, nil, fmt.Errorf("Mysterious error dialing TCP for SSH (we already succeeded at least once) : %s", err) + return nil, nil, fmt.Errorf("mysterious error dialing TCP for SSH (we already succeeded at least once): %s", err) } session, err := conn.NewSession()