Skip to content
This repository was archived by the owner on Oct 6, 2025. It is now read-only.

Commit 3a0fab7

Browse files
committed
Progress bar with more info
More info than previous progress bar Signed-off-by: Eric Curtin <[email protected]>
1 parent 08a8afe commit 3a0fab7

File tree

1 file changed

+158
-3
lines changed

1 file changed

+158
-3
lines changed

desktop/desktop.go

Lines changed: 158 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,14 @@ import (
99
"html"
1010
"io"
1111
"net/http"
12+
"os"
13+
"runtime"
1214
"strconv"
1315
"strings"
16+
"syscall"
1417
"time"
18+
"unsafe"
1519

16-
"github.com/docker/go-units"
1720
"github.com/docker/model-distribution/distribution"
1821
"github.com/docker/model-runner/pkg/inference"
1922
dmrm "github.com/docker/model-runner/pkg/inference/models"
@@ -106,6 +109,154 @@ func (c *Client) Status() Status {
106109
}
107110
}
108111

112+
func humanReadableSize(size float64) string {
113+
units := []string{"B", "kB", "MB", "GB", "TB", "PB"}
114+
i := 0
115+
for size >= 1024 && i < len(units)-1 {
116+
size /= 1024
117+
i++
118+
}
119+
return fmt.Sprintf("%.2f%s", size, units[i])
120+
}
121+
122+
func humanReadableSizePad(size float64, width int) string {
123+
return fmt.Sprintf("%*s", width, humanReadableSize(size))
124+
}
125+
126+
func humanReadableTimePad(seconds int64, width int) string {
127+
var s string
128+
if seconds < 60 {
129+
s = fmt.Sprintf("%ds", seconds)
130+
} else if seconds < 3600 {
131+
s = fmt.Sprintf("%dm %02ds", seconds/60, seconds%60)
132+
} else {
133+
s = fmt.Sprintf("%dh %02dm %02ds", seconds/3600, (seconds%3600)/60, seconds%60)
134+
}
135+
return fmt.Sprintf("%*s", width, s)
136+
}
137+
138+
// ProgressBarState tracks the running totals and timing for speed/ETA
139+
type ProgressBarState struct {
140+
LastDownloaded uint64
141+
LastTime time.Time
142+
StartTime time.Time
143+
UpdateInterval time.Duration // New: interval between updates
144+
lastPrint time.Time // New: last time the progress bar was printed
145+
}
146+
147+
// formatBar calculates the bar width and filled bar string.
148+
func (pbs *ProgressBarState) formatBar(percent float64, termWidth int, prefix, suffix string) string {
149+
barWidth := termWidth - len(prefix) - len(suffix) - 4
150+
if barWidth < 10 {
151+
barWidth = 10
152+
}
153+
filled := int(percent / 100 * float64(barWidth))
154+
if filled > barWidth {
155+
filled = barWidth
156+
}
157+
bar := strings.Repeat("█", filled) + strings.Repeat(" ", barWidth-filled)
158+
return bar
159+
}
160+
161+
// calcSpeed calculates the current download speed.
162+
func (pbs *ProgressBarState) calcSpeed(current uint64, now time.Time) float64 {
163+
elapsed := now.Sub(pbs.LastTime).Seconds()
164+
if elapsed <= 0 {
165+
return 0
166+
}
167+
168+
speed := float64(current-pbs.LastDownloaded) / elapsed
169+
pbs.LastTime = now
170+
pbs.LastDownloaded = current
171+
172+
return speed
173+
}
174+
175+
// formatSuffix returns the suffix string showing human readable sizes, speed, and ETA.
176+
func (pbs *ProgressBarState) fmtSuffix(current, total uint64, speed float64, eta int64) string {
177+
return fmt.Sprintf("%s/%s %s/s %s",
178+
humanReadableSizePad(float64(current), 10),
179+
humanReadableSize(float64(total)),
180+
humanReadableSizePad(speed, 10),
181+
humanReadableTimePad(eta, 16),
182+
)
183+
}
184+
185+
// calcETA calculates the estimated time remaining.
186+
func (pbs *ProgressBarState) calcETA(current, total uint64, speed float64) int64 {
187+
if speed <= 0 {
188+
return 0
189+
}
190+
return int64(float64(total-current) / speed)
191+
}
192+
193+
// printProgressBar prints/updates a progress bar in the terminal
194+
// Only prints if UpdateInterval has passed since last print, or always if interval=0
195+
func (pbs *ProgressBarState) printProgressBar(current, total uint64) {
196+
if pbs.StartTime.IsZero() {
197+
pbs.StartTime = time.Now()
198+
pbs.LastTime = pbs.StartTime
199+
pbs.LastDownloaded = current
200+
pbs.lastPrint = pbs.StartTime
201+
}
202+
203+
now := time.Now()
204+
// Only update display if enough time passed,
205+
// unless interval is 0 (always print)
206+
if pbs.UpdateInterval > 0 && now.Sub(pbs.lastPrint) < pbs.UpdateInterval && current != total {
207+
return
208+
}
209+
210+
pbs.lastPrint = now
211+
termWidth := getTerminalWidth()
212+
percent := float64(current) / float64(total) * 100
213+
prefix := fmt.Sprintf("%3.0f%% |", percent)
214+
speed := pbs.calcSpeed(current, now)
215+
eta := pbs.calcETA(current, total, speed)
216+
suffix := pbs.fmtSuffix(current, total, speed, eta)
217+
bar := pbs.formatBar(percent, termWidth, prefix, suffix)
218+
fmt.Fprintf(os.Stderr, "\r%s%s| %s", prefix, bar, suffix)
219+
}
220+
221+
func getTerminalWidthUnix() (int, error) {
222+
type winsize struct {
223+
Row uint16
224+
Col uint16
225+
Xpixel uint16
226+
Ypixel uint16
227+
}
228+
ws := &winsize{}
229+
retCode, _, errno := syscall.Syscall6(
230+
syscall.SYS_IOCTL,
231+
uintptr(os.Stdout.Fd()),
232+
uintptr(syscall.TIOCGWINSZ),
233+
uintptr(unsafe.Pointer(ws)),
234+
0, 0, 0,
235+
)
236+
if int(retCode) == -1 {
237+
return 0, errno
238+
}
239+
return int(ws.Col), nil
240+
}
241+
242+
// getTerminalSize tries to get the terminal width (default 80 if fails)
243+
func getTerminalWidth() int {
244+
var width int
245+
var err error
246+
default_width := 80
247+
if runtime.GOOS == "windows" { // to be implemented
248+
return default_width
249+
}
250+
251+
width, err = getTerminalWidthUnix()
252+
if width == 0 || err != nil {
253+
return default_width
254+
}
255+
256+
return width
257+
}
258+
259+
109260
func (c *Client) Pull(model string, ignoreRuntimeMemoryCheck bool, progress func(string)) (string, bool, error) {
110261
model = normalizeHuggingFaceModelName(model)
111262
jsonData, err := json.Marshal(dmrm.ModelCreateRequest{From: model, IgnoreRuntimeMemoryCheck: ignoreRuntimeMemoryCheck})
@@ -130,10 +281,14 @@ func (c *Client) Pull(model string, ignoreRuntimeMemoryCheck bool, progress func
130281
}
131282

132283
progressShown := false
133-
current := uint64(0) // Track cumulative progress across all layers
284+
// Track cumulative progress across all layers
285+
current := uint64(0)
134286
layerProgress := make(map[string]uint64) // Track progress per layer ID
135287

136288
scanner := bufio.NewScanner(resp.Body)
289+
pbs := &ProgressBarState{
290+
UpdateInterval: time.Millisecond * 100,
291+
}
137292
for scanner.Scan() {
138293
progressLine := scanner.Text()
139294
if progressLine == "" {
@@ -159,7 +314,7 @@ func (c *Client) Pull(model string, ignoreRuntimeMemoryCheck bool, progress func
159314
current += layerCurrent
160315
}
161316

162-
progress(fmt.Sprintf("Downloaded %s of %s", units.CustomSize("%.2f%s", float64(current), 1000.0, []string{"B", "kB", "MB", "GB", "TB", "PB", "EB", "ZB", "YB"}), units.CustomSize("%.2f%s", float64(progressMsg.Total), 1000.0, []string{"B", "kB", "MB", "GB", "TB", "PB", "EB", "ZB", "YB"})))
317+
pbs.printProgressBar(current, progressMsg.Total)
163318
progressShown = true
164319
case "error":
165320
return "", progressShown, fmt.Errorf("error pulling model: %s", progressMsg.Message)

0 commit comments

Comments
 (0)