Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
101 changes: 78 additions & 23 deletions Sources/HuggingFace/Hub/HubClient+Datasets.swift
Original file line number Diff line number Diff line change
Expand Up @@ -53,17 +53,22 @@ extension HubClient {
revision: String? = nil,
full: Bool? = nil
) async throws -> Dataset {
let path: String
var url = httpClient.host
.appending(path: "api")
.appending(path: "datasets")
.appending(path: id.namespace)
.appending(path: id.name)
if let revision {
path = "/api/datasets/\(id.namespace)/\(id.name)/revision/\(revision)"
} else {
path = "/api/datasets/\(id.namespace)/\(id.name)"
url =
url
.appending(path: "revision")
.appending(component: revision)
}

var params: [String: Value] = [:]
if let full { params["full"] = .bool(full) }

return try await httpClient.fetch(.get, path, params: params)
return try await httpClient.fetch(.get, url: url, params: params)
}

/// Gets all available dataset tags hosted in the Hub.
Expand Down Expand Up @@ -128,8 +133,14 @@ extension HubClient {
/// - Returns: `true` if the request was cancelled successfully.
/// - Throws: An error if the request fails.
public func cancelDatasetAccessRequest(_ id: Repo.ID) async throws -> Bool {
let path = "/api/datasets/\(id.namespace)/\(id.name)/user-access-request/cancel"
let result: Bool = try await httpClient.fetch(.post, path)
let url = httpClient.host
.appending(path: "api")
.appending(path: "datasets")
.appending(path: id.namespace)
.appending(path: id.name)
.appending(path: "user-access-request")
.appending(path: "cancel")
let result: Bool = try await httpClient.fetch(.post, url: url)
return result
}

Expand All @@ -139,8 +150,14 @@ extension HubClient {
/// - Returns: `true` if access was granted successfully.
/// - Throws: An error if the request fails.
public func grantDatasetAccess(_ id: Repo.ID) async throws -> Bool {
let path = "/api/datasets/\(id.namespace)/\(id.name)/user-access-request/grant"
let result: Bool = try await httpClient.fetch(.post, path)
let url = httpClient.host
.appending(path: "api")
.appending(path: "datasets")
.appending(path: id.namespace)
.appending(path: id.name)
.appending(path: "user-access-request")
.appending(path: "grant")
let result: Bool = try await httpClient.fetch(.post, url: url)
return result
}

Expand All @@ -150,8 +167,14 @@ extension HubClient {
/// - Returns: `true` if the request was handled successfully.
/// - Throws: An error if the request fails.
public func handleDatasetAccessRequest(_ id: Repo.ID) async throws -> Bool {
let path = "/api/datasets/\(id.namespace)/\(id.name)/user-access-request/handle"
let result: Bool = try await httpClient.fetch(.post, path)
let url = httpClient.host
.appending(path: "api")
.appending(path: "datasets")
.appending(path: id.namespace)
.appending(path: id.name)
.appending(path: "user-access-request")
.appending(path: "handle")
let result: Bool = try await httpClient.fetch(.post, url: url)
return result
}

Expand All @@ -166,8 +189,14 @@ extension HubClient {
_ id: Repo.ID,
status: AccessRequest.Status
) async throws -> [AccessRequest] {
let path = "/api/datasets/\(id.namespace)/\(id.name)/user-access-request/\(status.rawValue)"
return try await httpClient.fetch(.get, path)
let url = httpClient.host
.appending(path: "api")
.appending(path: "datasets")
.appending(path: id.namespace)
.appending(path: id.name)
.appending(path: "user-access-request")
.appending(path: status.rawValue)
return try await httpClient.fetch(.get, url: url)
}

/// Gets user access report for a dataset repository.
Expand All @@ -176,8 +205,12 @@ extension HubClient {
/// - Returns: User access report data.
/// - Throws: An error if the request fails.
public func getDatasetUserAccessReport(_ id: Repo.ID) async throws -> Data {
let path = "/datasets/\(id.namespace)/\(id.name)/user-access-report"
return try await httpClient.fetchData(.get, path)
let url = httpClient.host
.appending(path: "datasets")
.appending(path: id.namespace)
.appending(path: id.name)
.appending(path: "user-access-report")
return try await httpClient.fetchData(.get, url: url)
}

// MARK: - Dataset Advanced Features
Expand All @@ -193,13 +226,18 @@ extension HubClient {
_ id: Repo.ID,
resourceGroupId: String?
) async throws -> ResourceGroup {
let path = "/api/datasets/\(id.namespace)/\(id.name)/resource-group"
let url = httpClient.host
.appending(path: "api")
.appending(path: "datasets")
.appending(path: id.namespace)
.appending(path: id.name)
.appending(path: "resource-group")

let params: [String: Value] = [
"resourceGroupId": resourceGroupId.map { .string($0) } ?? .null
]

return try await httpClient.fetch(.post, path, params: params)
return try await httpClient.fetch(.post, url: url, params: params)
}

/// Scans a dataset repository.
Expand All @@ -208,8 +246,13 @@ extension HubClient {
/// - Returns: `true` if the scan was initiated successfully.
/// - Throws: An error if the request fails.
public func scanDataset(_ id: Repo.ID) async throws -> Bool {
let path = "/api/datasets/\(id.namespace)/\(id.name)/scan"
let result: Bool = try await httpClient.fetch(.post, path)
let url = httpClient.host
.appending(path: "api")
.appending(path: "datasets")
.appending(path: id.namespace)
.appending(path: id.name)
.appending(path: "scan")
let result: Bool = try await httpClient.fetch(.post, url: url)
return result
}

Expand All @@ -228,14 +271,20 @@ extension HubClient {
tag: String,
message: String? = nil
) async throws -> Bool {
let path = "/api/datasets/\(id.namespace)/\(id.name)/tag/\(revision)"
let url = httpClient.host
.appending(path: "api")
.appending(path: "datasets")
.appending(path: id.namespace)
.appending(path: id.name)
.appending(path: "tag")
.appending(component: revision)

let params: [String: Value] = [
"tag": .string(tag),
"message": message.map { .string($0) } ?? .null,
]

let result: Bool = try await httpClient.fetch(.post, path, params: params)
let result: Bool = try await httpClient.fetch(.post, url: url, params: params)
return result
}

Expand All @@ -252,14 +301,20 @@ extension HubClient {
revision: String,
message: String
) async throws -> String {
let path = "/api/datasets/\(id.namespace)/\(id.name)/super-squash/\(revision)"
let url = httpClient.host
.appending(path: "api")
.appending(path: "datasets")
.appending(path: id.namespace)
.appending(path: id.name)
.appending(path: "super-squash")
.appending(component: revision)

let params: [String: Value] = [
"message": .string(message)
]

struct Response: Decodable { let commitID: String }
let resp: Response = try await httpClient.fetch(.post, path, params: params)
let resp: Response = try await httpClient.fetch(.post, url: url, params: params)
return resp.commitID
}
}
57 changes: 45 additions & 12 deletions Sources/HuggingFace/Hub/HubClient+Files.swift
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,14 @@ public extension HubClient {
branch: String = "main",
message: String? = nil
) async throws -> (path: String, commit: String?) {
let urlPath = "/api/\(kind.pluralized)/\(repo)/upload/\(branch)"
var request = try await httpClient.createRequest(.post, urlPath)
let url = httpClient.host
.appending(path: "api")
.appending(path: kind.pluralized)
.appending(path: repo.namespace)
.appending(path: repo.name)
.appending(path: "upload")
.appending(component: branch)
var request = try await httpClient.createRequest(.post, url: url)

let boundary = "----hf-\(UUID().uuidString)"
request.setValue(
Expand Down Expand Up @@ -195,8 +201,13 @@ public extension HubClient {
}

let endpoint = useRaw ? "raw" : "resolve"
let urlPath = "/\(repo)/\(endpoint)/\(revision)/\(repoPath)"
var request = try await httpClient.createRequest(.get, urlPath)
let url = httpClient.host
.appending(path: repo.namespace)
.appending(path: repo.name)
.appending(path: endpoint)
.appending(component: revision)
.appending(path: repoPath)
var request = try await httpClient.createRequest(.get, url: url)
request.cachePolicy = cachePolicy

let (data, response) = try await session.data(for: request)
Expand Down Expand Up @@ -265,8 +276,13 @@ public extension HubClient {
}

let endpoint = useRaw ? "raw" : "resolve"
let urlPath = "/\(repo)/\(endpoint)/\(revision)/\(repoPath)"
var request = try await httpClient.createRequest(.get, urlPath)
let url = httpClient.host
.appending(path: repo.namespace)
.appending(path: repo.name)
.appending(path: endpoint)
.appending(component: revision)
.appending(path: repoPath)
var request = try await httpClient.createRequest(.get, url: url)
request.cachePolicy = cachePolicy

let (tempURL, response) = try await session.download(
Expand Down Expand Up @@ -424,7 +440,13 @@ public extension HubClient {
branch: String = "main",
message: String
) async throws {
let urlPath = "/api/\(kind.pluralized)/\(repo)/commit/\(branch)"
let url = httpClient.host
.appending(path: "api")
.appending(path: kind.pluralized)
.appending(path: repo.namespace)
.appending(path: repo.name)
.appending(path: "commit")
.appending(component: branch)
let operations = repoPaths.map { path in
Value.object(["op": .string("delete"), "path": .string(path)])
}
Expand All @@ -433,7 +455,7 @@ public extension HubClient {
"operations": .array(operations),
]

let _: Bool = try await httpClient.fetch(.post, urlPath, params: params)
let _: Bool = try await httpClient.fetch(.post, url: url, params: params)
}
}

Expand Down Expand Up @@ -474,10 +496,16 @@ public extension HubClient {
revision: String = "main",
recursive: Bool = true
) async throws -> [Git.TreeEntry] {
let urlPath = "/api/\(kind.pluralized)/\(repo)/tree/\(revision)"
let url = httpClient.host
.appending(path: "api")
.appending(path: kind.pluralized)
.appending(path: repo.namespace)
.appending(path: repo.name)
.appending(path: "tree")
.appending(component: revision)
let params: [String: Value]? = recursive ? ["recursive": .bool(true)] : nil

return try await httpClient.fetch(.get, urlPath, params: params)
return try await httpClient.fetch(.get, url: url, params: params)
}

/// Get file information
Expand All @@ -493,8 +521,13 @@ public extension HubClient {
kind _: Repo.Kind = .model,
revision: String = "main"
) async throws -> File {
let urlPath = "/\(repo)/resolve/\(revision)/\(repoPath)"
var request = try await httpClient.createRequest(.head, urlPath)
let url = httpClient.host
.appending(path: repo.namespace)
.appending(path: repo.name)
.appending(path: "resolve")
.appending(component: revision)
.appending(path: repoPath)
var request = try await httpClient.createRequest(.head, url: url)
request.setValue("bytes=0-0", forHTTPHeaderField: "Range")

do {
Expand Down
Loading