Async http client for pull/push (#95)

* Use async http client for pull/push

* Don't update progress too frequently

* Removed unused variable

* Rebased after added tests
This commit is contained in:
Fedor Korotkov 2022-05-20 11:04:21 -04:00 committed by GitHub
parent fec803277d
commit 35904dc637
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 197 additions and 75 deletions

View File

@ -1,5 +1,14 @@
{
"pins" : [
{
"identity" : "async-http-client",
"kind" : "remoteSourceControl",
"location" : "https://github.com/swift-server/async-http-client",
"state" : {
"revision" : "24425989dadab6d6e4167174791a23d4e2a6d0c3",
"version" : "1.10.0"
}
},
{
"identity" : "dynamic",
"kind" : "remoteSourceControl",
@ -27,6 +36,60 @@
"version" : "0.8.1"
}
},
{
"identity" : "swift-log",
"kind" : "remoteSourceControl",
"location" : "https://github.com/apple/swift-log.git",
"state" : {
"revision" : "5d66f7ba25daf4f94100e7022febf3c75e37a6c7",
"version" : "1.4.2"
}
},
{
"identity" : "swift-nio",
"kind" : "remoteSourceControl",
"location" : "https://github.com/apple/swift-nio.git",
"state" : {
"revision" : "124119f0bb12384cef35aa041d7c3a686108722d",
"version" : "2.40.0"
}
},
{
"identity" : "swift-nio-extras",
"kind" : "remoteSourceControl",
"location" : "https://github.com/apple/swift-nio-extras.git",
"state" : {
"revision" : "8eea84ec6144167354387ef9244b0939f5852dc8",
"version" : "1.11.0"
}
},
{
"identity" : "swift-nio-http2",
"kind" : "remoteSourceControl",
"location" : "https://github.com/apple/swift-nio-http2.git",
"state" : {
"revision" : "72bcaf607b40d7c51044f65b0f5ed8581a911832",
"version" : "1.21.0"
}
},
{
"identity" : "swift-nio-ssl",
"kind" : "remoteSourceControl",
"location" : "https://github.com/apple/swift-nio-ssl.git",
"state" : {
"revision" : "1750873bce84b4129b5303655cce2c3d35b9ed3a",
"version" : "2.19.0"
}
},
{
"identity" : "swift-nio-transport-services",
"kind" : "remoteSourceControl",
"location" : "https://github.com/apple/swift-nio-transport-services.git",
"state" : {
"revision" : "1a4692acb88156e3da1b0c6732a8a38b2a744166",
"version" : "1.12.0"
}
},
{
"identity" : "swift-parsing",
"kind" : "remoteSourceControl",

View File

@ -1,7 +1,6 @@
// swift-tools-version:5.6
import PackageDescription
let package = Package(
name: "Tart",
platforms: [
@ -12,15 +11,18 @@ let package = Package(
],
dependencies: [
.package(url: "https://github.com/apple/swift-argument-parser", from: "1.1.2"),
.package(url: "https://github.com/pointfreeco/swift-parsing", from: "0.9.2"),
.package(url: "https://github.com/mhdhejazi/Dynamic", branch: "master"),
.package(url: "https://github.com/pointfreeco/swift-parsing", from: "0.9.2"),
.package(url: "https://github.com/swift-server/async-http-client", from: "1.10.0"),
],
targets: [
.executableTarget(name: "tart", dependencies: [
.product(name: "ArgumentParser", package: "swift-argument-parser"),
.product(name: "AsyncHTTPClient", package: "async-http-client"),
.product(name: "Dynamic", package: "Dynamic"),
.product(name: "Parsing", package: "swift-parsing"),
]),
.testTarget(name: "TartTests", dependencies: ["tart"])
]
)

View File

@ -3,6 +3,7 @@ import Foundation
public class ProgressObserver: NSObject {
@objc var progressToObserve: Progress
var observation: NSKeyValueObservation?
var lastTimeUpdated = Date.now
public init(_ progress: Progress) {
progressToObserve = progress
@ -11,7 +12,11 @@ public class ProgressObserver: NSObject {
func log(_ renderer: Logger) {
renderer.appendNewLine(ProgressObserver.lineToRender(progressToObserve))
observation = observe(\.progressToObserve.fractionCompleted) { progress, _ in
renderer.updateLastLine(ProgressObserver.lineToRender(self.progressToObserve))
let currentTime = Date.now
if self.progressToObserve.isFinished || currentTime.timeIntervalSince(self.lastTimeUpdated) >= 1.0 {
self.lastTimeUpdated = currentTime
renderer.updateLastLine(ProgressObserver.lineToRender(self.progressToObserve))
}
}
}

View File

@ -1,12 +1,30 @@
import Foundation
import NIOCore
import NIOHTTP1
import AsyncHTTPClient
enum RegistryError: Error {
case UnexpectedHTTPStatusCode(when: String, code: Int, details: String = "")
case UnexpectedHTTPStatusCode(when: String, code: UInt, details: String = "")
case MissingLocationHeader
case AuthFailed(why: String, details: String = "")
case MalformedHeader(why: String)
}
extension HTTPClientResponse.Body {
func readTextResponse() async throws -> String? {
let data = try await readResponse()
return String(decoding: data, as: UTF8.self)
}
func readResponse() async throws -> Data {
var result = Data()
for try await part in self {
result.append(Data(buffer: part))
}
return result
}
}
struct TokenResponse: Decodable {
let defaultIssuedAt = Date()
let defaultExpiresIn = 60
@ -46,7 +64,7 @@ struct TokenResponse: Decodable {
(issuedAt ?? defaultIssuedAt) + TimeInterval(expiresIn ?? defaultExpiresIn)
}
}
var isValid: Bool {
get {
Date() < tokenExpiresAt
@ -54,6 +72,8 @@ struct TokenResponse: Decodable {
}
}
fileprivate let httpClient = HTTPClient(eventLoopGroupProvider: .createNew)
class Registry {
var baseURL: URL
var namespace: String
@ -76,41 +96,43 @@ class Registry {
}
func ping() async throws {
let (_, response) = try await endpointRequest("GET", "/v2/")
if response.statusCode != 200 {
throw RegistryError.UnexpectedHTTPStatusCode(when: "doing ping", code: response.statusCode)
let response = try await endpointRequest(.GET, "/v2/")
if response.status != .ok {
throw RegistryError.UnexpectedHTTPStatusCode(when: "doing ping", code: response.status.code)
}
}
func pushManifest(reference: String, manifest: OCIManifest) async throws -> String {
let manifestJSON = try JSONEncoder().encode(manifest)
let (responseData, response) = try await endpointRequest("PUT", "\(namespace)/manifests/\(reference)",
let response = try await endpointRequest(.PUT, "\(namespace)/manifests/\(reference)",
headers: ["Content-Type": manifest.mediaType],
body: manifestJSON)
if response.statusCode != 201 {
throw RegistryError.UnexpectedHTTPStatusCode(when: "pushing manifest", code: response.statusCode,
details: String(decoding: responseData, as: UTF8.self))
if response.status != .created {
throw RegistryError.UnexpectedHTTPStatusCode(when: "pushing manifest", code: response.status.code,
details: try await response.body.readTextResponse() ?? "")
}
return Digest.hash(manifestJSON)
}
public func pullManifest(reference: String) async throws -> (OCIManifest, Data) {
let (responseData, response) = try await endpointRequest("GET", "\(namespace)/manifests/\(reference)",
let response = try await endpointRequest(.GET, "\(namespace)/manifests/\(reference)",
headers: ["Accept": ociManifestMediaType])
if response.statusCode != 200 {
throw RegistryError.UnexpectedHTTPStatusCode(when: "pulling manifest", code: response.statusCode,
details: String(decoding: responseData, as: UTF8.self))
if response.status != .ok {
let body = try await response.body.readTextResponse()
throw RegistryError.UnexpectedHTTPStatusCode(when: "pulling manifest", code: response.status.code,
details: body ?? "")
}
let manifest = try JSONDecoder().decode(OCIManifest.self, from: responseData)
let manifestData = try await response.body.readResponse()
let manifest = try JSONDecoder().decode(OCIManifest.self, from: manifestData)
return (manifest, responseData)
return (manifest, manifestData)
}
private func uploadLocationFromResponse(response: HTTPURLResponse) throws -> URLComponents {
guard let uploadLocationRaw = response.value(forHTTPHeaderField: "Location") else {
private func uploadLocationFromResponse(_ response: HTTPClientResponse) throws -> URLComponents {
guard let uploadLocationRaw = response.headers.first(name: "Location") else {
throw RegistryError.MissingLocationHeader
}
@ -123,15 +145,16 @@ class Registry {
public func pushBlob(fromData: Data, chunkSize: Int = 5 * 1024 * 1024) async throws -> String {
// Initiate a blob upload
let (postData, postResponse) = try await endpointRequest("POST", "\(namespace)/blobs/uploads/",
let postResponse = try await endpointRequest(.POST, "\(namespace)/blobs/uploads/",
headers: ["Content-Length": "0"])
if postResponse.statusCode != 202 {
throw RegistryError.UnexpectedHTTPStatusCode(when: "pushing blob (POST)", code: postResponse.statusCode,
details: String(decoding: postData, as: UTF8.self))
if postResponse.status != .accepted {
let body = try await postResponse.body.readTextResponse()
throw RegistryError.UnexpectedHTTPStatusCode(when: "pushing blob (POST)", code: postResponse.status.code,
details: body ?? "")
}
// Figure out where to upload the blob
let uploadLocation = try uploadLocationFromResponse(response: postResponse)
let uploadLocation = try uploadLocationFromResponse(postResponse)
// Upload the blob
let headers = [
@ -144,33 +167,37 @@ class Registry {
"digest": digest,
]
let (putData, putResponse) = try await rawRequest("PUT", uploadLocation, headers: headers, parameters: parameters,
let putResponse = try await rawRequest(.PUT, uploadLocation, headers: headers, parameters: parameters,
body: fromData)
if putResponse.statusCode != 201 {
throw RegistryError.UnexpectedHTTPStatusCode(when: "pushing blob (PUT)", code: putResponse.statusCode,
details: String(decoding: putData, as: UTF8.self))
if putResponse.status != .created {
let body = try await postResponse.body.readTextResponse()
throw RegistryError.UnexpectedHTTPStatusCode(when: "pushing blob (PUT)", code: putResponse.status.code,
details: body ?? "")
}
return digest
}
public func pullBlob(_ digest: String) async throws -> Data {
let (putData, putResponse) = try await endpointRequest("GET", "\(namespace)/blobs/\(digest)")
if putResponse.statusCode != 200 {
throw RegistryError.UnexpectedHTTPStatusCode(when: "pulling blob", code: putResponse.statusCode,
details: String(decoding: putData, as: UTF8.self))
public func pullBlob(_ digest: String, handler: (ByteBuffer) throws -> Void) async throws {
let response = try await endpointRequest(.GET, "\(namespace)/blobs/\(digest)")
if response.status != .ok {
let body = try await response.body.readTextResponse()
throw RegistryError.UnexpectedHTTPStatusCode(when: "pulling blob", code: response.status.code,
details: body ?? "")
}
return putData
for try await part in response.body {
try handler(part)
}
}
private func endpointRequest(
_ method: String,
_ method: HTTPMethod,
_ endpoint: String,
headers: Dictionary<String, String> = Dictionary(),
parameters: Dictionary<String, String> = Dictionary(),
body: Data? = nil
) async throws -> (Data, HTTPURLResponse) {
) async throws -> HTTPClientResponse {
let url = URL(string: endpoint, relativeTo: baseURL)!
let urlComponents = URLComponents(url: url, resolvingAgainstBaseURL: true)!
@ -178,12 +205,12 @@ class Registry {
}
private func rawRequest(
_ method: String,
_ method: HTTPMethod,
_ urlComponents: URLComponents,
headers: Dictionary<String, String> = Dictionary(),
parameters: Dictionary<String, String> = Dictionary(),
body: Data? = nil
) async throws -> (Data, HTTPURLResponse) {
) async throws -> HTTPClientResponse {
var urlComponents = urlComponents
if urlComponents.queryItems == nil {
@ -193,31 +220,33 @@ class Registry {
URLQueryItem(name: key, value: value)
})
var request = URLRequest(url: urlComponents.url!)
request.httpMethod = method
var request = HTTPClientRequest(url: urlComponents.string!)
request.method = method
for (key, value) in headers {
request.addValue(value, forHTTPHeaderField: key)
request.headers.add(name: key, value: value)
}
if body != nil {
request.body = HTTPClientRequest.Body.bytes(body!)
}
request.httpBody = body
// Invalidate token if it has expired
if currentAuthToken?.isValid == false {
currentAuthToken = nil
}
var (data, response) = try await authAwareRequest(request: request)
var response = try await authAwareRequest(request: request)
if response.statusCode == 401 {
if response.status == .unauthorized {
try await auth(response: response)
(data, response) = try await authAwareRequest(request: request)
response = try await authAwareRequest(request: request)
}
return (data, response)
return response
}
private func auth(response: HTTPURLResponse) async throws {
private func auth(response: HTTPClientResponse) async throws {
// Process WWW-Authenticate header
guard let wwwAuthenticateRaw = response.value(forHTTPHeaderField: "WWW-Authenticate") else {
guard let wwwAuthenticateRaw = response.headers.first(name: "WWW-Authenticate") else {
throw RegistryError.AuthFailed(why: "got HTTP 401, but WWW-Authenticate header is missing")
}
@ -257,24 +286,24 @@ class Registry {
headers["Authorization"] = "Basic \(encodedCredentials!)"
}
let (tokenResponseRaw, response) = try await rawRequest("GET", authenticateURL, headers: headers)
if response.statusCode != 200 {
throw RegistryError.AuthFailed(why: "received unexpected HTTP status code \(response.statusCode) "
+ "while retrieving an authentication token", details: String(decoding: tokenResponseRaw, as: UTF8.self))
let response = try await rawRequest(.GET, authenticateURL, headers: headers)
if response.status != .ok {
let body = try await response.body.readTextResponse() ?? ""
throw RegistryError.AuthFailed(why: "received unexpected HTTP status code \(response.status.code) "
+ "while retrieving an authentication token", details: body)
}
currentAuthToken = try TokenResponse.parse(fromData: tokenResponseRaw)
let bodyData = try await response.body.readResponse()
currentAuthToken = try TokenResponse.parse(fromData: bodyData)
}
private func authAwareRequest(request: URLRequest) async throws -> (Data, HTTPURLResponse) {
private func authAwareRequest(request: HTTPClientRequest) async throws -> HTTPClientResponse {
var request = request
if let token = currentAuthToken {
request.addValue("Bearer \(token.token)", forHTTPHeaderField: "Authorization")
request.headers.add(name: "Authorization", value: "Bearer \(token.token)")
}
let (responseData, response) = try await URLSession.shared.data(for: request)
return (responseData, response as! HTTPURLResponse)
return try await httpClient.execute(request, deadline: .distantFuture)
}
}

View File

@ -4,7 +4,7 @@ import Compression
enum OCIError: Error {
case ShouldBeExactlyOneLayer
case ShouldBeAtLeastOneLayer
case FailedToCreateDiskFile
case FailedToCreateVmFile
}
extension VMDirectory {
@ -31,8 +31,14 @@ extension VMDirectory {
if configLayers.count != 1 {
throw OCIError.ShouldBeExactlyOneLayer
}
let configData = try await registry.pullBlob(configLayers.first!.digest)
try VMConfig(fromData: configData).save(toURL: configURL)
if !FileManager.default.createFile(atPath: configURL.path, contents: nil) {
throw OCIError.FailedToCreateVmFile
}
let configFile = try FileHandle(forWritingTo: configURL)
try await registry.pullBlob(configLayers.first!.digest) { buffer in
configFile.write(Data(buffer: buffer))
}
try configFile.close()
// Pull VM's disk layers and decompress them sequentially into a disk file
let diskLayers = manifest.layers.filter {
@ -42,7 +48,7 @@ extension VMDirectory {
throw OCIError.ShouldBeAtLeastOneLayer
}
if !FileManager.default.createFile(atPath: diskURL.path, contents: nil) {
throw OCIError.FailedToCreateDiskFile
throw OCIError.FailedToCreateVmFile
}
let disk = try FileHandle(forWritingTo: diskURL)
let filter = try OutputFilter(.decompress, using: .lz4, bufferCapacity: Self.bufferSizeBytes) { data in
@ -52,18 +58,23 @@ extension VMDirectory {
}
// Progress
let diskCompressedSize: Int64 = Int64(diskLayers.map {$0.size}.reduce(0) {$0 + $1})
let diskCompressedSize: Int64 = Int64(diskLayers.map {
$0.size
}
.reduce(0) {
$0 + $1
})
let prettyDiskSize = String(format: "%.1f", Double(diskCompressedSize) / 1_000_000_000.0)
defaultLogger.appendNewLine("pulling disk (\(prettyDiskSize) GB compressed)...")
let progress = Progress(totalUnitCount: diskCompressedSize)
ProgressObserver(progress).log(defaultLogger)
for diskLayer in diskLayers {
let diskData = try await registry.pullBlob(diskLayer.digest)
try filter.write(diskData)
// Progress
progress.completedUnitCount += Int64(diskLayer.size)
try await registry.pullBlob(diskLayer.digest) { buffer in
let data = Data(buffer: buffer)
try filter.write(data)
progress.completedUnitCount += Int64(data.count)
}
}
try filter.finalize()
try disk.close()
@ -77,8 +88,14 @@ extension VMDirectory {
if nvramLayers.count != 1 {
throw OCIError.ShouldBeExactlyOneLayer
}
let nvramData = try await registry.pullBlob(nvramLayers.first!.digest)
try nvramData.write(to: nvramURL)
if !FileManager.default.createFile(atPath: nvramURL.path, contents: nil) {
throw OCIError.FailedToCreateVmFile
}
let nvram = try FileHandle(forWritingTo: nvramURL)
try await registry.pullBlob(nvramLayers.first!.digest) { buffer in
nvram.write(Data(buffer: buffer))
}
try nvram.close()
}
func pushToRegistry(registry: Registry, references: [String]) async throws {
@ -92,7 +109,7 @@ extension VMDirectory {
// Progress
let diskSize = try FileManager.default.attributesOfItem(atPath: diskURL.path)[.size] as! Int64
defaultLogger.appendNewLine("pushing disk... this will take a while...")
let progress = Progress(totalUnitCount: diskSize)
ProgressObserver(progress).log(defaultLogger)
@ -102,7 +119,7 @@ extension VMDirectory {
let disk = try FileHandle(forReadingFrom: diskURL)
let compressingFilter = try InputFilter<Data>(.compress, using: .lz4, bufferCapacity: Self.bufferSizeBytes) { _ in
let data = try disk.read(upToCount: Self.bufferSizeBytes)
progress.completedUnitCount += Int64(data?.count ?? 0)
return data

View File

@ -33,7 +33,10 @@ final class RegistryTests: XCTestCase {
XCTAssertEqual("sha256:d7a8fbb307d7809469ca9abcb0082e4f8d5651e46d3cdb762d02d0bf37c9e592", pushedBlobDigest)
// Pull it
let pulledBlob = try await registry.pullBlob(pushedBlobDigest)
var pulledBlob = Data()
try await registry.pullBlob(pushedBlobDigest) { buffer in
pulledBlob.append(Data(buffer: buffer))
}
// Ensure that both blobs are identical
XCTAssertEqual(pushedBlob, pulledBlob)
@ -48,7 +51,10 @@ final class RegistryTests: XCTestCase {
let largeBlobDigest = try await registry.pushBlob(fromData: largeBlobToPush)
// Pull it
let pulledLargeBlob = try await registry.pullBlob(largeBlobDigest)
var pulledLargeBlob = Data()
try await registry.pullBlob(largeBlobDigest) { buffer in
pulledLargeBlob.append(Data(buffer: buffer))
}
// Ensure that both blobs are identical
XCTAssertEqual(largeBlobToPush, pulledLargeBlob)