@@ -9,11 +9,14 @@ import (
99 "html"
1010 "io"
1111 "net/http"
12+ "os"
13+ "runtime"
1214 "strconv"
1315 "strings"
16+ "syscall"
1417 "time"
18+ "unsafe"
1519
16- "github.com/docker/go-units"
1720 "github.com/docker/model-distribution/distribution"
1821 "github.com/docker/model-runner/pkg/inference"
1922 dmrm "github.com/docker/model-runner/pkg/inference/models"
@@ -106,6 +109,154 @@ func (c *Client) Status() Status {
106109 }
107110}
108111
112+ func humanReadableSize (size float64 ) string {
113+ units := []string {"B" , "kB" , "MB" , "GB" , "TB" , "PB" }
114+ i := 0
115+ for size >= 1024 && i < len (units )- 1 {
116+ size /= 1024
117+ i ++
118+ }
119+ return fmt .Sprintf ("%.2f%s" , size , units [i ])
120+ }
121+
122+ func humanReadableSizePad (size float64 , width int ) string {
123+ return fmt .Sprintf ("%*s" , width , humanReadableSize (size ))
124+ }
125+
126+ func humanReadableTimePad (seconds int64 , width int ) string {
127+ var s string
128+ if seconds < 60 {
129+ s = fmt .Sprintf ("%ds" , seconds )
130+ } else if seconds < 3600 {
131+ s = fmt .Sprintf ("%dm %02ds" , seconds / 60 , seconds % 60 )
132+ } else {
133+ s = fmt .Sprintf ("%dh %02dm %02ds" , seconds / 3600 , (seconds % 3600 )/ 60 , seconds % 60 )
134+ }
135+ return fmt .Sprintf ("%*s" , width , s )
136+ }
137+
138+ // ProgressBarState tracks the running totals and timing for speed/ETA
139+ type ProgressBarState struct {
140+ LastDownloaded uint64
141+ LastTime time.Time
142+ StartTime time.Time
143+ UpdateInterval time.Duration // New: interval between updates
144+ lastPrint time.Time // New: last time the progress bar was printed
145+ }
146+
147+ // formatBar calculates the bar width and filled bar string.
148+ func (pbs * ProgressBarState ) formatBar (percent float64 , termWidth int , prefix , suffix string ) string {
149+ barWidth := termWidth - len (prefix ) - len (suffix ) - 4
150+ if barWidth < 10 {
151+ barWidth = 10
152+ }
153+ filled := int (percent / 100 * float64 (barWidth ))
154+ if filled > barWidth {
155+ filled = barWidth
156+ }
157+ bar := strings .Repeat ("█" , filled ) + strings .Repeat (" " , barWidth - filled )
158+ return bar
159+ }
160+
161+ // calcSpeed calculates the current download speed.
162+ func (pbs * ProgressBarState ) calcSpeed (current uint64 , now time.Time ) float64 {
163+ elapsed := now .Sub (pbs .LastTime ).Seconds ()
164+ if elapsed <= 0 {
165+ return 0
166+ }
167+
168+ speed := float64 (current - pbs .LastDownloaded ) / elapsed
169+ pbs .LastTime = now
170+ pbs .LastDownloaded = current
171+
172+ return speed
173+ }
174+
175+ // formatSuffix returns the suffix string showing human readable sizes, speed, and ETA.
176+ func (pbs * ProgressBarState ) fmtSuffix (current , total uint64 , speed float64 , eta int64 ) string {
177+ return fmt .Sprintf ("%s/%s %s/s %s" ,
178+ humanReadableSizePad (float64 (current ), 10 ),
179+ humanReadableSize (float64 (total )),
180+ humanReadableSizePad (speed , 10 ),
181+ humanReadableTimePad (eta , 16 ),
182+ )
183+ }
184+
185+ // calcETA calculates the estimated time remaining.
186+ func (pbs * ProgressBarState ) calcETA (current , total uint64 , speed float64 ) int64 {
187+ if speed <= 0 {
188+ return 0
189+ }
190+ return int64 (float64 (total - current ) / speed )
191+ }
192+
193+ // printProgressBar prints/updates a progress bar in the terminal
194+ // Only prints if UpdateInterval has passed since last print, or always if interval=0
195+ func (pbs * ProgressBarState ) printProgressBar (current , total uint64 ) {
196+ if pbs .StartTime .IsZero () {
197+ pbs .StartTime = time .Now ()
198+ pbs .LastTime = pbs .StartTime
199+ pbs .LastDownloaded = current
200+ pbs .lastPrint = pbs .StartTime
201+ }
202+
203+ now := time .Now ()
204+ // Only update display if enough time passed,
205+ // unless interval is 0 (always print)
206+ if pbs .UpdateInterval > 0 && now .Sub (pbs .lastPrint ) < pbs .UpdateInterval && current != total {
207+ return
208+ }
209+
210+ pbs .lastPrint = now
211+ termWidth := getTerminalWidth ()
212+ percent := float64 (current ) / float64 (total ) * 100
213+ prefix := fmt .Sprintf ("%3.0f%% |" , percent )
214+ speed := pbs .calcSpeed (current , now )
215+ eta := pbs .calcETA (current , total , speed )
216+ suffix := pbs .fmtSuffix (current , total , speed , eta )
217+ bar := pbs .formatBar (percent , termWidth , prefix , suffix )
218+ fmt .Fprintf (os .Stderr , "\r %s%s| %s" , prefix , bar , suffix )
219+ }
220+
221+ func getTerminalWidthUnix () (int , error ) {
222+ type winsize struct {
223+ Row uint16
224+ Col uint16
225+ Xpixel uint16
226+ Ypixel uint16
227+ }
228+ ws := & winsize {}
229+ retCode , _ , errno := syscall .Syscall6 (
230+ syscall .SYS_IOCTL ,
231+ uintptr (os .Stdout .Fd ()),
232+ uintptr (syscall .TIOCGWINSZ ),
233+ uintptr (unsafe .Pointer (ws )),
234+ 0 , 0 , 0 ,
235+ )
236+ if int (retCode ) == - 1 {
237+ return 0 , errno
238+ }
239+ return int (ws .Col ), nil
240+ }
241+
242+ // getTerminalSize tries to get the terminal width (default 80 if fails)
243+ func getTerminalWidth () int {
244+ var width int
245+ var err error
246+ default_width := 80
247+ if runtime .GOOS == "windows" { // to be implemented
248+ return default_width
249+ }
250+
251+ width , err = getTerminalWidthUnix ()
252+ if width == 0 || err != nil {
253+ return default_width
254+ }
255+
256+ return width
257+ }
258+
259+
109260func (c * Client ) Pull (model string , ignoreRuntimeMemoryCheck bool , progress func (string )) (string , bool , error ) {
110261 model = normalizeHuggingFaceModelName (model )
111262 jsonData , err := json .Marshal (dmrm.ModelCreateRequest {From : model , IgnoreRuntimeMemoryCheck : ignoreRuntimeMemoryCheck })
@@ -130,10 +281,14 @@ func (c *Client) Pull(model string, ignoreRuntimeMemoryCheck bool, progress func
130281 }
131282
132283 progressShown := false
133- current := uint64 (0 ) // Track cumulative progress across all layers
284+ // Track cumulative progress across all layers
285+ current := uint64 (0 )
134286 layerProgress := make (map [string ]uint64 ) // Track progress per layer ID
135287
136288 scanner := bufio .NewScanner (resp .Body )
289+ pbs := & ProgressBarState {
290+ UpdateInterval : time .Millisecond * 100 ,
291+ }
137292 for scanner .Scan () {
138293 progressLine := scanner .Text ()
139294 if progressLine == "" {
@@ -159,7 +314,7 @@ func (c *Client) Pull(model string, ignoreRuntimeMemoryCheck bool, progress func
159314 current += layerCurrent
160315 }
161316
162- progress ( fmt . Sprintf ( "Downloaded %s of %s" , units . CustomSize ( "%.2f%s" , float64 ( current ), 1000.0 , [] string { "B" , "kB" , "MB" , "GB" , "TB" , "PB" , "EB" , "ZB" , "YB" }), units . CustomSize ( "%.2f%s" , float64 ( progressMsg .Total ), 1000.0 , [] string { "B" , "kB" , "MB" , "GB" , "TB" , "PB" , "EB" , "ZB" , "YB" })) )
317+ pbs . printProgressBar ( current , progressMsg .Total )
163318 progressShown = true
164319 case "error" :
165320 return "" , progressShown , fmt .Errorf ("error pulling model: %s" , progressMsg .Message )
0 commit comments