From a5878b46ecade7b87153848e136e9db1f2dd1774 Mon Sep 17 00:00:00 2001 From: Valeriy Selitskiy <239034+iamwavecut@users.noreply.github.com> Date: Tue, 14 Oct 2025 23:45:38 +0200 Subject: [PATCH] Add file source handling and upload payload management --- bot.go | 252 ++++++++-------- configs.go | 461 ++++++++++++++++++------------ docs/internals/uploading-files.md | 48 ++-- file_source.go | 170 +++++++++++ file_source_test.go | 108 +++++++ upload_payload.go | 98 +++++++ upload_payload_test.go | 177 ++++++++++++ 7 files changed, 980 insertions(+), 334 deletions(-) create mode 100644 file_source.go create mode 100644 file_source_test.go create mode 100644 upload_payload.go create mode 100644 upload_payload_test.go diff --git a/bot.go b/bot.go index fb8f9b42..165a5069 100644 --- a/bot.go +++ b/bot.go @@ -36,6 +36,23 @@ type BotAPI struct { mu sync.RWMutex } +type requestPayload struct { + body io.Reader + closer io.Closer + contentType string +} + +func (p requestPayload) close() { + if p.closer != nil { + _ = p.closer.Close() + } +} + +type requestDebug struct { + params Params + fileCount int +} + // NewBotAPI creates a new BotAPI instance. // // It requires a token, provided by @BotFather on Telegram. @@ -93,25 +110,106 @@ func buildParams(in Params) url.Values { return out } -// MakeRequest makes a request to a specific endpoint with our token. -func (bot *BotAPI) MakeRequest(endpoint string, params Params) (*APIResponse, error) { - return bot.MakeRequestWithContext(context.Background(), endpoint, params) +func buildFormPayload(params Params) requestPayload { + values := buildParams(params) + reader := strings.NewReader(values.Encode()) + return requestPayload{ + body: reader, + contentType: "application/x-www-form-urlencoded", + } } -func (bot *BotAPI) MakeRequestWithContext(ctx context.Context, endpoint string, params Params) (*APIResponse, error) { +func buildMultipartPayload(params Params, files []RequestFile) (requestPayload, error) { + reader, writer := io.Pipe() + multipartWriter := multipart.NewWriter(writer) + + go func() { + defer writer.Close() + defer multipartWriter.Close() + + for field, value := range params { + if err := multipartWriter.WriteField(field, value); err != nil { + writer.CloseWithError(err) + return + } + } + + for _, file := range files { + source, err := resolveRequestFileData(file.Data) + if err != nil { + writer.CloseWithError(err) + return + } + + if source.kindIsUpload() { + desc, err := source.openUpload() + if err != nil { + writer.CloseWithError(err) + return + } + + part, err := multipartWriter.CreateFormFile(file.Name, desc.name) + if err != nil { + desc.reader.Close() + writer.CloseWithError(err) + return + } + + if _, err = io.Copy(part, desc.reader); err != nil { + desc.reader.Close() + writer.CloseWithError(err) + return + } + + if err = desc.reader.Close(); err != nil { + writer.CloseWithError(err) + return + } + + continue + } + + value, err := source.referenceValue() + if err != nil { + writer.CloseWithError(err) + return + } + + if err = multipartWriter.WriteField(file.Name, value); err != nil { + writer.CloseWithError(err) + return + } + } + }() + + return requestPayload{ + body: reader, + closer: reader, + contentType: multipartWriter.FormDataContentType(), + }, nil +} + +func (bot *BotAPI) executeRequest(ctx context.Context, endpoint string, payload requestPayload, debugInfo requestDebug) (*APIResponse, error) { + defer payload.close() + if bot.Debug { - log.Printf("Endpoint: %s, params: %v\n", endpoint, params) + if debugInfo.fileCount > 0 { + log.Printf("Endpoint: %s, params: %v, with %d files\n", endpoint, debugInfo.params, debugInfo.fileCount) + } else { + log.Printf("Endpoint: %s, params: %v\n", endpoint, debugInfo.params) + } } method := fmt.Sprintf(bot.apiEndpoint, bot.Token, endpoint) - values := buildParams(params) - - req, err := http.NewRequestWithContext(ctx, "POST", method, strings.NewReader(values.Encode())) + req, err := http.NewRequestWithContext(ctx, "POST", method, payload.body) if err != nil { return &APIResponse{}, err } - req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + + if payload.contentType != "" { + req.Header.Set("Content-Type", payload.contentType) + } resp, err := bot.Client.Do(req) if err != nil { @@ -146,6 +244,16 @@ func (bot *BotAPI) MakeRequestWithContext(ctx context.Context, endpoint string, return &apiResp, nil } +// MakeRequest makes a request to a specific endpoint with our token. +func (bot *BotAPI) MakeRequest(endpoint string, params Params) (*APIResponse, error) { + return bot.MakeRequestWithContext(context.Background(), endpoint, params) +} + +func (bot *BotAPI) MakeRequestWithContext(ctx context.Context, endpoint string, params Params) (*APIResponse, error) { + payload := buildFormPayload(params) + return bot.executeRequest(ctx, endpoint, payload, requestDebug{params: params}) +} + // decodeAPIResponse decode response and return slice of bytes if debug enabled. // If debug disabled, just decode http.Response.Body stream to APIResponse struct // for efficient memory usage @@ -176,101 +284,15 @@ func (bot *BotAPI) UploadFiles(endpoint string, params Params, files []RequestFi } func (bot *BotAPI) UploadFilesWithContext(ctx context.Context, endpoint string, params Params, files []RequestFile) (*APIResponse, error) { - r, w := io.Pipe() - m := multipart.NewWriter(w) - - // This code modified from the very helpful @HirbodBehnam - // https://github.com/go-telegram-bot-api/telegram-bot-api/issues/354#issuecomment-663856473 - go func() { - defer w.Close() - defer m.Close() - - for field, value := range params { - if err := m.WriteField(field, value); err != nil { - w.CloseWithError(err) - return - } - } - - for _, file := range files { - if file.Data.NeedsUpload() { - name, reader, err := file.Data.UploadData() - if err != nil { - w.CloseWithError(err) - return - } - - part, err := m.CreateFormFile(file.Name, name) - if err != nil { - w.CloseWithError(err) - return - } - - if _, err := io.Copy(part, reader); err != nil { - w.CloseWithError(err) - return - } - - if closer, ok := reader.(io.ReadCloser); ok { - if err = closer.Close(); err != nil { - w.CloseWithError(err) - return - } - } - } else { - value := file.Data.SendData() - - if err := m.WriteField(file.Name, value); err != nil { - w.CloseWithError(err) - return - } - } - } - }() - - if bot.Debug { - log.Printf("Endpoint: %s, params: %v, with %d files\n", endpoint, params, len(files)) - } - - method := fmt.Sprintf(bot.apiEndpoint, bot.Token, endpoint) - - req, err := http.NewRequestWithContext(ctx, "POST", method, r) + payload, err := buildMultipartPayload(params, files) if err != nil { return nil, err } - req.Header.Set("Content-Type", m.FormDataContentType()) - - resp, err := bot.Client.Do(req) - if err != nil { - return nil, err - } - defer resp.Body.Close() - - var apiResp APIResponse - bytes, err := bot.decodeAPIResponse(resp.Body, &apiResp) - if err != nil { - return &apiResp, err - } - - if bot.Debug { - log.Printf("Endpoint: %s, response: %s\n", endpoint, string(bytes)) - } - - if !apiResp.Ok { - var parameters ResponseParameters - - if apiResp.Parameters != nil { - parameters = *apiResp.Parameters - } - - return &apiResp, &Error{ - Message: apiResp.Description, - ResponseParameters: parameters, - } - } - - return &apiResp, nil + return bot.executeRequest(ctx, endpoint, payload, requestDebug{ + params: params, + fileCount: len(files), + }) } // GetFileDirectURL returns direct URL to file @@ -313,16 +335,6 @@ func (bot *BotAPI) IsMessageToMe(message Message) bool { return strings.Contains(message.Text, "@"+bot.Self.UserName) } -func hasFilesNeedingUpload(files []RequestFile) bool { - for _, file := range files { - if file.Data.NeedsUpload() { - return true - } - } - - return false -} - // Request sends a Chattable to Telegram, and returns the APIResponse. func (bot *BotAPI) Request(c Chattable) (*APIResponse, error) { return bot.RequestWithContext(context.Background(), c) @@ -335,18 +347,11 @@ func (bot *BotAPI) RequestWithContext(ctx context.Context, c Chattable) (*APIRes } if t, ok := c.(Fileable); ok { - files := t.files() + payload := payloadFromFileable(t) + params = payload.applyInline(params) - // If we have files that need to be uploaded, we should delegate the - // request to UploadFile. - if hasFilesNeedingUpload(files) { - return bot.UploadFiles(t.method(), params, files) - } - - // However, if there are no files to be uploaded, there's likely things - // that need to be turned into params instead. - for _, file := range files { - params[file.Name] = file.Data.SendData() + if payload.needsUpload() { + return bot.UploadFilesWithContext(ctx, t.method(), params, payload.filesSlice()) } } @@ -576,9 +581,12 @@ func WriteToHTTPResponse(w http.ResponseWriter, c Chattable) error { } if t, ok := c.(Fileable); ok { - if hasFilesNeedingUpload(t.files()) { + payload := payloadFromFileable(t) + if payload.needsUpload() { return errors.New("unable to use http response to upload files") } + + params = payload.applyInline(params) } values := buildParams(params) diff --git a/configs.go b/configs.go index 669d8b0e..3d991153 100644 --- a/configs.go +++ b/configs.go @@ -1,11 +1,9 @@ package tgbotapi import ( - "bytes" "fmt" "io" "net/url" - "os" "strconv" ) @@ -175,15 +173,29 @@ type FileBytes struct { } func (fb FileBytes) NeedsUpload() bool { - return true + return fb.descriptor().kindIsUpload() } func (fb FileBytes) UploadData() (string, io.Reader, error) { - return fb.Name, bytes.NewReader(fb.Bytes), nil + desc, err := fb.descriptor().openUpload() + if err != nil { + return "", nil, err + } + + return desc.name, desc.reader, nil } func (fb FileBytes) SendData() string { - panic("FileBytes must be uploaded") + value, err := fb.descriptor().referenceValue() + if err != nil { + return "" + } + + return value +} + +func (fb FileBytes) descriptor() fileSource { + return newBytesSource(fb.Name, fb.Bytes) } // FileReader contains information about a reader to upload as a File. @@ -193,81 +205,145 @@ type FileReader struct { } func (fr FileReader) NeedsUpload() bool { - return true + return fr.descriptor().kindIsUpload() } func (fr FileReader) UploadData() (string, io.Reader, error) { - return fr.Name, fr.Reader, nil + desc, err := fr.descriptor().openUpload() + if err != nil { + return "", nil, err + } + + return desc.name, desc.reader, nil } func (fr FileReader) SendData() string { - panic("FileReader must be uploaded") + value, err := fr.descriptor().referenceValue() + if err != nil { + return "" + } + + return value +} + +func (fr FileReader) descriptor() fileSource { + return newReaderSource(fr.Name, fr.Reader) } // FilePath is a path to a local file. type FilePath string func (fp FilePath) NeedsUpload() bool { - return true + return fp.descriptor().kindIsUpload() } func (fp FilePath) UploadData() (string, io.Reader, error) { - fileHandle, err := os.Open(string(fp)) + desc, err := fp.descriptor().openUpload() if err != nil { return "", nil, err } - name := fileHandle.Name() - return name, fileHandle, err + return desc.name, desc.reader, nil } func (fp FilePath) SendData() string { - panic("FilePath must be uploaded") + value, err := fp.descriptor().referenceValue() + if err != nil { + return "" + } + + return value +} + +func (fp FilePath) descriptor() fileSource { + return newPathSource(string(fp)) } // FileURL is a URL to use as a file for a request. type FileURL string func (fu FileURL) NeedsUpload() bool { - return false + return fu.descriptor().kindIsUpload() } func (fu FileURL) UploadData() (string, io.Reader, error) { - panic("FileURL cannot be uploaded") + desc, err := fu.descriptor().openUpload() + if err != nil { + return "", nil, err + } + + return desc.name, desc.reader, nil } func (fu FileURL) SendData() string { - return string(fu) + value, err := fu.descriptor().referenceValue() + if err != nil { + return "" + } + + return value +} + +func (fu FileURL) descriptor() fileSource { + return newURLSource(string(fu)) } // FileID is an ID of a file already uploaded to Telegram. type FileID string func (fi FileID) NeedsUpload() bool { - return false + return fi.descriptor().kindIsUpload() } func (fi FileID) UploadData() (string, io.Reader, error) { - panic("FileID cannot be uploaded") + desc, err := fi.descriptor().openUpload() + if err != nil { + return "", nil, err + } + + return desc.name, desc.reader, nil } func (fi FileID) SendData() string { - return string(fi) + value, err := fi.descriptor().referenceValue() + if err != nil { + return "" + } + + return value +} + +func (fi FileID) descriptor() fileSource { + return newFileIDSource(string(fi)) } // fileAttach is an internal file type used for processed media groups. type fileAttach string func (fa fileAttach) NeedsUpload() bool { - return false + return fa.descriptor().kindIsUpload() } func (fa fileAttach) UploadData() (string, io.Reader, error) { - panic("fileAttach cannot be uploaded") + desc, err := fa.descriptor().openUpload() + if err != nil { + return "", nil, err + } + + return desc.name, desc.reader, nil } func (fa fileAttach) SendData() string { - return string(fa) + value, err := fa.descriptor().referenceValue() + if err != nil { + return "" + } + + return value +} + +func (fa fileAttach) descriptor() fileSource { + return newAttachSource(string(fa)) } // LogOutConfig is a request to log out of the cloud Bot API server. @@ -489,19 +565,14 @@ func (config PhotoConfig) method() string { } func (config PhotoConfig) files() []RequestFile { - files := []RequestFile{{ - Name: "photo", - Data: config.File, - }} - - if config.Thumb != nil { - files = append(files, RequestFile{ - Name: "thumbnail", - Data: config.Thumb, - }) - } + return config.filePayload().filesSlice() +} - return files +func (config PhotoConfig) filePayload() uploadPayload { + payload := newUploadPayload() + payload.Add("photo", config.File) + payload.Add("thumbnail", config.Thumb) + return payload } // AudioConfig contains information about a SendAudio request. @@ -537,19 +608,14 @@ func (config AudioConfig) method() string { } func (config AudioConfig) files() []RequestFile { - files := []RequestFile{{ - Name: "audio", - Data: config.File, - }} - - if config.Thumb != nil { - files = append(files, RequestFile{ - Name: "thumbnail", - Data: config.Thumb, - }) - } + return config.filePayload().filesSlice() +} - return files +func (config AudioConfig) filePayload() uploadPayload { + payload := newUploadPayload() + payload.Add("audio", config.File) + payload.Add("thumbnail", config.Thumb) + return payload } // DocumentConfig contains information about a SendDocument request. @@ -584,19 +650,14 @@ func (config DocumentConfig) method() string { } func (config DocumentConfig) files() []RequestFile { - files := []RequestFile{{ - Name: "document", - Data: config.File, - }} - - if config.Thumb != nil { - files = append(files, RequestFile{ - Name: "thumbnail", - Data: config.Thumb, - }) - } + return config.filePayload().filesSlice() +} - return files +func (config DocumentConfig) filePayload() uploadPayload { + payload := newUploadPayload() + payload.Add("document", config.File) + payload.Add("thumbnail", config.Thumb) + return payload } // StickerConfig contains information about a SendSticker request. @@ -620,10 +681,13 @@ func (config StickerConfig) method() string { } func (config StickerConfig) files() []RequestFile { - return []RequestFile{{ - Name: "sticker", - Data: config.File, - }} + return config.filePayload().filesSlice() +} + +func (config StickerConfig) filePayload() uploadPayload { + payload := newUploadPayload() + payload.Add("sticker", config.File) + return payload } // VideoConfig contains information about a SendVideo request. @@ -672,25 +736,15 @@ func (config VideoConfig) method() string { } func (config VideoConfig) files() []RequestFile { - files := []RequestFile{{ - Name: "video", - Data: config.File, - }} - - if config.Thumb != nil { - files = append(files, RequestFile{ - Name: "thumbnail", - Data: config.Thumb, - }) - } + return config.filePayload().filesSlice() +} - if config.Cover != nil { - files = append(files, RequestFile{ - Name: "cover", - Data: config.Cover, - }) - } - return files +func (config VideoConfig) filePayload() uploadPayload { + payload := newUploadPayload() + payload.Add("video", config.File) + payload.Add("thumbnail", config.Thumb) + payload.Add("cover", config.Cover) + return payload } // AnimationConfig contains information about a SendAnimation request. @@ -734,19 +788,14 @@ func (config AnimationConfig) method() string { } func (config AnimationConfig) files() []RequestFile { - files := []RequestFile{{ - Name: "animation", - Data: config.File, - }} - - if config.Thumb != nil { - files = append(files, RequestFile{ - Name: "thumbnail", - Data: config.Thumb, - }) - } + return config.filePayload().filesSlice() +} - return files +func (config AnimationConfig) filePayload() uploadPayload { + payload := newUploadPayload() + payload.Add("animation", config.File) + payload.Add("thumbnail", config.Thumb) + return payload } // VideoNoteConfig contains information about a SendVideoNote request. @@ -771,19 +820,14 @@ func (config VideoNoteConfig) method() string { } func (config VideoNoteConfig) files() []RequestFile { - files := []RequestFile{{ - Name: "video_note", - Data: config.File, - }} - - if config.Thumb != nil { - files = append(files, RequestFile{ - Name: "thumbnail", - Data: config.Thumb, - }) - } + return config.filePayload().filesSlice() +} - return files +func (config VideoNoteConfig) filePayload() uploadPayload { + payload := newUploadPayload() + payload.Add("video_note", config.File) + payload.Add("thumbnail", config.Thumb) + return payload } // Use this method to send paid media to channel chats. On success, the sent Message is returned. @@ -808,34 +852,35 @@ func (config PaidMediaConfig) params() (Params, error) { params.AddNonEmpty("parse_mode", config.ParseMode) params.AddBool("show_caption_above_media", config.ShowCaptionAboveMedia) - media := []InputMedia{config.Media} - newMedia := prepareInputMediaForParams(media) - err = params.AddInterface("media", newMedia[0]) + var payload uploadPayload + + if config.Media != nil { + prepared, uploads := prepareInputMedia([]InputMedia{config.Media}) + payload = uploads + err = params.AddInterface("media", prepared[0]) + } else { + err = params.AddInterface("media", nil) + } + if err != nil { return params, err } err = params.AddInterface("caption_entities", config.CaptionEntities) + params = payload.applyInline(params) return params, err } func (config PaidMediaConfig) files() []RequestFile { - files := []RequestFile{} - - if config.Media.getMedia().NeedsUpload() { - files = append(files, RequestFile{ - Name: "file-0", - Data: config.Media.getMedia(), - }) - } + return config.filePayload().filesSlice() +} - if thumb := config.Media.getThumb(); thumb != nil && thumb.NeedsUpload() { - files = append(files, RequestFile{ - Name: "file-0-thumb", - Data: thumb, - }) +func (config PaidMediaConfig) filePayload() uploadPayload { + if config.Media == nil { + return newUploadPayload() } - return files + _, payload := prepareInputMedia([]InputMedia{config.Media}) + return payload } func (config PaidMediaConfig) method() string { @@ -871,19 +916,14 @@ func (config VoiceConfig) method() string { } func (config VoiceConfig) files() []RequestFile { - files := []RequestFile{{ - Name: "voice", - Data: config.File, - }} - - if config.Thumb != nil { - files = append(files, RequestFile{ - Name: "thumbnail", - Data: config.Thumb, - }) - } + return config.filePayload().filesSlice() +} - return files +func (config VoiceConfig) filePayload() uploadPayload { + payload := newUploadPayload() + payload.Add("voice", config.File) + payload.Add("thumbnail", config.Thumb) + return payload } // LocationConfig contains information about a SendLocation request. @@ -1240,15 +1280,25 @@ func (config EditMessageMediaConfig) params() (Params, error) { return params, err } - preparedMedia := prepareInputMediaForParams([]InputMedia{config.Media}) + preparedMedia, payload := prepareInputMedia([]InputMedia{config.Media}) err = params.AddInterface("media", preparedMedia[0]) + if err != nil { + return params, err + } + + params = payload.applyInline(params) return params, err } func (config EditMessageMediaConfig) files() []RequestFile { - return prepareInputMediaForFiles([]InputMedia{config.Media}) + return config.filePayload().filesSlice() +} + +func (config EditMessageMediaConfig) filePayload() uploadPayload { + _, payload := prepareInputMedia([]InputMedia{config.Media}) + return payload } // EditMessageReplyMarkupConfig allows you to modify the reply markup @@ -1418,14 +1468,13 @@ func (config WebhookConfig) params() (Params, error) { } func (config WebhookConfig) files() []RequestFile { - if config.Certificate != nil { - return []RequestFile{{ - Name: "certificate", - Data: config.Certificate, - }} - } + return config.filePayload().filesSlice() +} - return nil +func (config WebhookConfig) filePayload() uploadPayload { + payload := newUploadPayload() + payload.Add("certificate", config.Certificate) + return payload } // DeleteWebhookConfig is a helper to delete a webhook. @@ -2517,10 +2566,13 @@ func (config SetChatPhotoConfig) method() string { } func (config SetChatPhotoConfig) files() []RequestFile { - return []RequestFile{{ - Name: "photo", - Data: config.File, - }} + return config.filePayload().filesSlice() +} + +func (config SetChatPhotoConfig) filePayload() uploadPayload { + payload := newUploadPayload() + payload.Add("photo", config.File) + return payload } // DeleteChatPhotoConfig allows you to delete a group, supergroup, or channel's photo. @@ -2634,7 +2686,13 @@ func (config UploadStickerConfig) params() (Params, error) { } func (config UploadStickerConfig) files() []RequestFile { - return []RequestFile{config.Sticker} + return config.filePayload().filesSlice() +} + +func (config UploadStickerConfig) filePayload() uploadPayload { + payload := newUploadPayload() + payload.Add(config.Sticker.Name, config.Sticker.Data) + return payload } // NewStickerSetConfig allows creating a new sticker set. @@ -2666,11 +2724,15 @@ func (config NewStickerSetConfig) params() (Params, error) { } func (config NewStickerSetConfig) files() []RequestFile { - requestFiles := []RequestFile{} - for _, v := range config.Stickers { - requestFiles = append(requestFiles, v.Sticker) + return config.filePayload().filesSlice() +} + +func (config NewStickerSetConfig) filePayload() uploadPayload { + payload := newUploadPayload() + for _, sticker := range config.Stickers { + payload.Add(sticker.Sticker.Name, sticker.Sticker.Data) } - return requestFiles + return payload } // AddStickerConfig allows you to add a sticker to a set. @@ -2694,7 +2756,13 @@ func (config AddStickerConfig) params() (Params, error) { } func (config AddStickerConfig) files() []RequestFile { - return []RequestFile{config.Sticker.Sticker} + return config.filePayload().filesSlice() +} + +func (config AddStickerConfig) filePayload() uploadPayload { + payload := newUploadPayload() + payload.Add(config.Sticker.Sticker.Name, config.Sticker.Sticker.Data) + return payload } // SetStickerPositionConfig allows you to change the position of a sticker in a set. @@ -2896,10 +2964,13 @@ func (config SetStickerSetThumbConfig) params() (Params, error) { } func (config SetStickerSetThumbConfig) files() []RequestFile { - return []RequestFile{{ - Name: "thumbnail", - Data: config.Thumb, - }} + return config.filePayload().filesSlice() +} + +func (config SetStickerSetThumbConfig) filePayload() uploadPayload { + payload := newUploadPayload() + payload.Add("thumbnail", config.Thumb) + return payload } // SetChatStickerSetConfig allows you to set the sticker set for a supergroup. @@ -3139,13 +3210,25 @@ func (config MediaGroupConfig) params() (Params, error) { return nil, err } - err = params.AddInterface("media", prepareInputMediaForParams(config.Media)) + preparedMedia, payload := prepareInputMedia(config.Media) + + err = params.AddInterface("media", preparedMedia) + if err != nil { + return params, err + } + + params = payload.applyInline(params) return params, err } func (config MediaGroupConfig) files() []RequestFile { - return prepareInputMediaForFiles(config.Media) + return config.filePayload().filesSlice() +} + +func (config MediaGroupConfig) filePayload() uploadPayload { + _, payload := prepareInputMedia(config.Media) + return payload } // DiceConfig contains information about a sendDice request. @@ -3487,47 +3570,51 @@ func (config GetMyDefaultAdministratorRightsConfig) params() (Params, error) { return params, nil } -// prepareInputMediaForParams processes media items for API parameters. -// It creates a copy of the media array with files prepared for upload. -func prepareInputMediaForParams(inputMedia []InputMedia) []InputMedia { - newMedias := cloneMediaSlice(inputMedia) - for idx, media := range newMedias { - if media.getMedia().NeedsUpload() { - media.setUploadMedia(fmt.Sprintf("attach://file-%d", idx)) - } +// prepareInputMedia normalizes media payloads and gathers uploadable files. +func prepareInputMedia(inputMedia []InputMedia) ([]InputMedia, uploadPayload) { + prepared := cloneMediaSlice(inputMedia) + payload := newUploadPayload() - if thumb := media.getThumb(); thumb != nil && thumb.NeedsUpload() { - media.setUploadThumb(fmt.Sprintf("attach://file-%d-thumb", idx)) + for idx, media := range prepared { + if media == nil { + continue } - newMedias[idx] = media - } - - return newMedias -} - -// prepareInputMediaForFiles generates RequestFile objects for media items -// that need to be uploaded. -func prepareInputMediaForFiles(inputMedia []InputMedia) []RequestFile { - files := []RequestFile{} - - for idx, media := range inputMedia { - if media.getMedia() != nil && media.getMedia().NeedsUpload() { - files = append(files, RequestFile{ - Name: fmt.Sprintf("file-%d", idx), - Data: media.getMedia(), + fileRef := media.getMedia() + if fileRef != nil && fileRef.NeedsUpload() { + name := fmt.Sprintf("file-%d", idx) + media.setUploadMedia(fmt.Sprintf("attach://%s", name)) + payload.files = append(payload.files, RequestFile{ + Name: name, + Data: fileRef, }) } if thumb := media.getThumb(); thumb != nil && thumb.NeedsUpload() { - files = append(files, RequestFile{ - Name: fmt.Sprintf("file-%d-thumb", idx), + name := fmt.Sprintf("file-%d-thumb", idx) + media.setUploadThumb(fmt.Sprintf("attach://%s", name)) + payload.files = append(payload.files, RequestFile{ + Name: name, Data: thumb, }) } + + prepared[idx] = media } - return files + return prepared, payload +} + +// prepareInputMediaForParams processes media items for API parameters. +func prepareInputMediaForParams(inputMedia []InputMedia) []InputMedia { + prepared, _ := prepareInputMedia(inputMedia) + return prepared +} + +// prepareInputMediaForFiles generates RequestFile objects for media items. +func prepareInputMediaForFiles(inputMedia []InputMedia) []RequestFile { + _, payload := prepareInputMedia(inputMedia) + return payload.filesSlice() } func ptr[T any](v T) *T { diff --git a/docs/internals/uploading-files.md b/docs/internals/uploading-files.md index 1845269c..bb72eebb 100644 --- a/docs/internals/uploading-files.md +++ b/docs/internals/uploading-files.md @@ -19,30 +19,29 @@ file named `photo`. All we have to do is set that single field with the correct value (either a string or multipart file). Methods like `sendDocument` take two file uploads, a `document` and a `thumb`. These are pretty straightforward. -Remembering that the `Fileable` interface only requires one method, let's -implement it for `DocumentConfig`. +Remembering that the `Fileable` interface only requires one method, we expose a +`filePayload` helper that declares the intent for each field. The helper uses a +builder that decides whether an entry becomes an inline parameter or a streamed +upload. ```go +func (config DocumentConfig) filePayload() uploadPayload { + payload := newUploadPayload() + payload.Add("document", config.File) + payload.Add("thumbnail", config.Thumb) + return payload +} + func (config DocumentConfig) files() []RequestFile { - // We can have multiple files, so we'll create an array. We also know that - // there always is a document file, so initialize the array with that. - files := []RequestFile{{ - Name: "document", - Data: config.File, - }} - - // We'll only add a file if we have one. - if config.Thumb != nil { - files = append(files, RequestFile{ - Name: "thumb", - Data: config.Thumb, - }) - } - - return files + return config.filePayload().filesSlice() } ``` +Calling `payload.Add` automatically promotes remote references (for example +`FileID`) into inline params while keeping uploadable variants in the returned +slice. This keeps the configuration declarative and avoids the panic guards +that used to exist on individual file types. + Telegram also supports the `attach://` syntax (discussed more later) for thumbnails, but there's no reason to make things more complicated. @@ -78,10 +77,9 @@ A `MediaGroupConfig` stores all the media in an array of interfaces. We now have all the data we need to upload, but how do we figure out field names for uploads? We didn't specify `attach://unique-file` anywhere. -When the library goes to upload the files, it looks at the `params` and `files` -for the Config. The params are generated by transforming the file into a value -more suitable for uploading, file IDs and URLs are untouched but uploaded types -are all changed into `attach://file-%d`. When collecting a list of files to -upload, it names them the same way. This creates a nearly transparent way of -handling multiple files in the background without the user having to consider -what's going on. +When the library goes to upload the files, it materializes the payload builder +and inspects which entries require streaming. The prepared media is rewritten +to reference `attach://file-%d` slots, and the corresponding uploads are added +to the multipart request under matching names. Remote references stay in the +params map, so the calling code does not have to manually synchronize field +names or worry about transitioning between upload and reference modes. diff --git a/file_source.go b/file_source.go new file mode 100644 index 00000000..dcb136a2 --- /dev/null +++ b/file_source.go @@ -0,0 +1,170 @@ +package tgbotapi + +import ( + "bytes" + "errors" + "io" + "os" +) + +var ( + errFileSourceNoUpload = errors.New("file source does not support uploads") + errFileSourceNoReference = errors.New("file source does not support reference values") +) + +type fileSourceKind int + +const ( + fileSourceUpload fileSourceKind = iota + fileSourceFileID + fileSourceURL + fileSourceAttach + fileSourceInline +) + +type uploadDescriptor struct { + name string + reader io.ReadCloser +} + +type fileSource struct { + kind fileSourceKind + uploadFn func() (uploadDescriptor, error) + referenceFn func() (string, error) +} + +func (s fileSource) kindIsUpload() bool { + return s.kind == fileSourceUpload +} + +func (s fileSource) openUpload() (uploadDescriptor, error) { + if s.uploadFn == nil { + return uploadDescriptor{}, errFileSourceNoUpload + } + + return s.uploadFn() +} + +func (s fileSource) referenceValue() (string, error) { + if s.referenceFn == nil { + return "", errFileSourceNoReference + } + + return s.referenceFn() +} + +func newBytesSource(name string, data []byte) fileSource { + return fileSource{ + kind: fileSourceUpload, + uploadFn: func() (uploadDescriptor, error) { + return uploadDescriptor{ + name: name, + reader: io.NopCloser(bytes.NewReader(data)), + }, nil + }, + } +} + +func newReaderSource(name string, reader io.Reader) fileSource { + return fileSource{ + kind: fileSourceUpload, + uploadFn: func() (uploadDescriptor, error) { + if rc, ok := reader.(io.ReadCloser); ok { + return uploadDescriptor{name: name, reader: rc}, nil + } + + return uploadDescriptor{ + name: name, + reader: io.NopCloser(reader), + }, nil + }, + } +} + +func newPathSource(path string) fileSource { + return fileSource{ + kind: fileSourceUpload, + uploadFn: func() (uploadDescriptor, error) { + handle, err := os.Open(path) + if err != nil { + return uploadDescriptor{}, err + } + + return uploadDescriptor{ + name: handle.Name(), + reader: handle, + }, nil + }, + } +} + +func newURLSource(raw string) fileSource { + return fileSource{ + kind: fileSourceURL, + referenceFn: func() (string, error) { + return raw, nil + }, + } +} + +func newFileIDSource(id string) fileSource { + return fileSource{ + kind: fileSourceFileID, + referenceFn: func() (string, error) { + return id, nil + }, + } +} + +func newAttachSource(value string) fileSource { + return fileSource{ + kind: fileSourceAttach, + referenceFn: func() (string, error) { + return value, nil + }, + } +} + +type fileSourceProvider interface { + descriptor() fileSource +} + +func resolveRequestFileData(data RequestFileData) (fileSource, error) { + if provider, ok := data.(fileSourceProvider); ok { + return provider.descriptor(), nil + } + + if data == nil { + return fileSource{}, errors.New("file data is nil") + } + + if data.NeedsUpload() { + return fileSource{ + kind: fileSourceUpload, + uploadFn: func() (uploadDescriptor, error) { + name, reader, err := data.UploadData() + if err != nil { + return uploadDescriptor{}, err + } + + if rc, ok := reader.(io.ReadCloser); ok { + return uploadDescriptor{name: name, reader: rc}, nil + } + + return uploadDescriptor{ + name: name, + reader: io.NopCloser(reader), + }, nil + }, + }, nil + } + + value := data.SendData() + + return fileSource{ + kind: fileSourceInline, + referenceFn: func() (string, error) { + return value, nil + }, + }, nil +} diff --git a/file_source_test.go b/file_source_test.go new file mode 100644 index 00000000..89ff26d5 --- /dev/null +++ b/file_source_test.go @@ -0,0 +1,108 @@ +package tgbotapi + +import ( + "bytes" + "io" + "os" + "testing" +) + +func TestRequestFileDataSources(t *testing.T) { + tmp, err := os.CreateTemp(t.TempDir(), "upload-*.txt") + if err != nil { + t.Fatalf("create temp file: %v", err) + } + + _, err = tmp.WriteString("temp-data") + if err != nil { + t.Fatalf("write temp file: %v", err) + } + + err = tmp.Close() + if err != nil { + t.Fatalf("close temp file: %v", err) + } + + cases := []struct { + name string + data RequestFileData + expectUpload bool + expectValue string + }{ + { + name: "bytes upload", + data: FileBytes{Name: "data.bin", Bytes: []byte("content")}, + expectUpload: true, + }, + { + name: "reader upload", + data: FileReader{Name: "reader.dat", Reader: bytes.NewBufferString("stream")}, + expectUpload: true, + }, + { + name: "path upload", + data: FilePath(tmp.Name()), + expectUpload: true, + }, + { + name: "remote url", + data: FileURL("https://example.com/demo"), + expectValue: "https://example.com/demo", + }, + { + name: "file id", + data: FileID("ABC123"), + expectValue: "ABC123", + }, + { + name: "attach value", + data: fileAttach("attach://demo"), + expectValue: "attach://demo", + }, + } + + for _, tc := range cases { + tc := tc + + t.Run(tc.name, func(t *testing.T) { + if tc.data.NeedsUpload() != tc.expectUpload { + t.Fatalf("unexpected upload flag: %v", tc.data.NeedsUpload()) + } + + if tc.expectUpload { + name, reader, err := tc.data.UploadData() + if err != nil { + t.Fatalf("upload error: %v", err) + } + + if name == "" { + t.Fatalf("expected upload name") + } + + all, err := io.ReadAll(reader) + if err != nil { + t.Fatalf("read upload: %v", err) + } + + if closer, ok := reader.(io.Closer); ok { + if err := closer.Close(); err != nil { + t.Fatalf("close upload reader: %v", err) + } + } + + if len(all) == 0 { + t.Fatalf("expected upload payload") + } + } else { + if _, _, err := tc.data.UploadData(); err == nil { + t.Fatalf("expected upload error") + } + + value := tc.data.SendData() + if value != tc.expectValue { + t.Fatalf("unexpected value: %q", value) + } + } + }) + } +} diff --git a/upload_payload.go b/upload_payload.go new file mode 100644 index 00000000..e33c7992 --- /dev/null +++ b/upload_payload.go @@ -0,0 +1,98 @@ +package tgbotapi + +type uploadPayload struct { + files []RequestFile + inline map[string]string +} + +func newUploadPayload() uploadPayload { + return uploadPayload{ + inline: map[string]string{}, + } +} + +func (p *uploadPayload) Add(field string, data RequestFileData) { + if data == nil { + return + } + + source, err := resolveRequestFileData(data) + if err != nil { + return + } + + if source.kindIsUpload() { + p.files = append(p.files, RequestFile{ + Name: field, + Data: data, + }) + return + } + + value, err := source.referenceValue() + if err != nil { + return + } + + if p.inline == nil { + p.inline = map[string]string{} + } + + p.inline[field] = value +} + +func (p *uploadPayload) AddUploadOnly(field string, data RequestFileData) { + if data == nil { + return + } + + source, err := resolveRequestFileData(data) + if err != nil { + return + } + + if source.kindIsUpload() { + p.files = append(p.files, RequestFile{ + Name: field, + Data: data, + }) + } +} + +func (p uploadPayload) needsUpload() bool { + return len(p.files) > 0 +} + +func (p uploadPayload) filesSlice() []RequestFile { + return p.files +} + +func (p uploadPayload) applyInline(params Params) Params { + if len(p.inline) == 0 { + return params + } + + if params == nil { + params = Params{} + } + + for key, value := range p.inline { + params[key] = value + } + + return params +} + +func payloadFromFileable(f Fileable) uploadPayload { + if provider, ok := f.(interface{ filePayload() uploadPayload }); ok { + return provider.filePayload() + } + + payload := newUploadPayload() + + for _, file := range f.files() { + payload.Add(file.Name, file.Data) + } + + return payload +} diff --git a/upload_payload_test.go b/upload_payload_test.go new file mode 100644 index 00000000..7e4d36b0 --- /dev/null +++ b/upload_payload_test.go @@ -0,0 +1,177 @@ +package tgbotapi + +import ( + "bytes" + "errors" + "io" + "mime" + "mime/multipart" + "testing" +) + +func TestUploadPayloadBuilder(t *testing.T) { + payload := newUploadPayload() + + payload.Add("photo", FileBytes{Name: "pic.jpg", Bytes: []byte("data")}) + payload.Add("thumb", FileID("file-id")) + payload.AddUploadOnly("skip", FileID("unused")) + + if !payload.needsUpload() { + t.Fatalf("expected upload payload to require upload") + } + + files := payload.filesSlice() + if len(files) != 1 { + t.Fatalf("expected single upload file, got %d", len(files)) + } + + if files[0].Name != "photo" { + t.Fatalf("unexpected upload field %q", files[0].Name) + } + + params := payload.applyInline(nil) + if params["thumb"] != "file-id" { + t.Fatalf("expected inline thumb value, got %q", params["thumb"]) + } + + if _, ok := params["skip"]; ok { + t.Fatalf("did not expect upload-only field in inline params") + } +} + +func TestPrepareInputMedia(t *testing.T) { + photo := NewInputMediaPhoto(FileBytes{Name: "image.png", Bytes: []byte("media")}) + video := NewInputMediaVideo(FileBytes{Name: "video.mp4", Bytes: []byte("clip")}) + video.Thumb = FileBytes{Name: "thumb.jpg", Bytes: []byte("thumb")} + + prepared, payload := prepareInputMedia([]InputMedia{&photo, &video}) + + if prepared[0].getMedia().SendData() != "attach://file-0" { + t.Fatalf("unexpected media ref: %q", prepared[0].getMedia().SendData()) + } + + if prepared[1].getMedia().SendData() != "attach://file-1" { + t.Fatalf("unexpected video ref: %q", prepared[1].getMedia().SendData()) + } + + if prepared[1].getThumb().SendData() != "attach://file-1-thumb" { + t.Fatalf("unexpected thumb ref: %q", prepared[1].getThumb().SendData()) + } + + files := payload.filesSlice() + if len(files) != 3 { + t.Fatalf("expected 3 upload parts, got %d", len(files)) + } + + expectedNames := map[string]struct{}{ + "file-0": {}, + "file-1": {}, + "file-1-thumb": {}, + } + + for _, f := range files { + if _, ok := expectedNames[f.Name]; !ok { + t.Fatalf("unexpected upload name %q", f.Name) + } + delete(expectedNames, f.Name) + } + + if len(expectedNames) != 0 { + t.Fatalf("missing upload fields: %v", expectedNames) + } +} + +func TestBuildMultipartPayload(t *testing.T) { + params := Params{ + "text": "hello", + } + + files := []RequestFile{ + { + Name: "photo", + Data: FileBytes{Name: "img.jpg", Bytes: []byte("jpeg-data")}, + }, + { + Name: "thumbnail", + Data: FileID("remote-thumb"), + }, + } + + payload, err := buildMultipartPayload(params, files) + if err != nil { + t.Fatalf("build multipart payload: %v", err) + } + + mediaType, attrs, err := mime.ParseMediaType(payload.contentType) + if err != nil { + t.Fatalf("parse media type: %v", err) + } + + if mediaType != "multipart/form-data" { + t.Fatalf("unexpected media type %q", mediaType) + } + + boundary := attrs["boundary"] + if boundary == "" { + t.Fatalf("missing boundary") + } + + rawBody, err := io.ReadAll(payload.body) + if err != nil { + t.Fatalf("read payload body: %v", err) + } + + if len(rawBody) == 0 { + t.Fatalf("empty multipart payload") + } + + reader := multipart.NewReader(bytes.NewReader(rawBody), boundary) + + var ( + foundText bool + foundThumb bool + foundUpload bool + uploadContents []byte + ) + + for { + part, err := reader.NextPart() + if err == io.EOF { + break + } + if err != nil { + t.Fatalf("read multipart: %v", err) + } + + data, err := io.ReadAll(part) + if err != nil && !errors.Is(err, io.ErrUnexpectedEOF) { + t.Fatalf("read part: %v", err) + } + + switch part.FormName() { + case "text": + foundText = bytes.Equal(data, []byte("hello")) + case "thumbnail": + foundThumb = bytes.Equal(data, []byte("remote-thumb")) + case "photo": + foundUpload = part.FileName() == "img.jpg" + uploadContents = append([]byte(nil), data...) + } + } + + if !foundText { + t.Fatalf("missing form field value") + } + + if !foundThumb { + t.Fatalf("missing inline thumb value") + } + + if !foundUpload { + t.Fatalf("missing upload part") + } + + if !bytes.Equal(uploadContents, []byte("jpeg-data")) { + t.Fatalf("unexpected upload payload %q", string(uploadContents)) + } +}