diff --git a/.cirrus.yml b/.cirrus.yml index f9921ec..aecdcb7 100644 --- a/.cirrus.yml +++ b/.cirrus.yml @@ -1,10 +1,15 @@ +persistent_worker: + labels: + os: darwin + arch: arm64 + +task: + name: Test + test_script: swift test + task: name: Build only_if: $CIRRUS_TAG == '' - persistent_worker: - labels: - os: darwin - arch: arm64 build_script: swift build --product tart sign_script: codesign --sign - --entitlements Resources/tart.entitlements --force .build/debug/tart binary_artifacts: @@ -13,10 +18,6 @@ task: task: name: Release only_if: $CIRRUS_TAG != '' - persistent_worker: - labels: - os: darwin - arch: arm64 env: GITHUB_TOKEN: ENCRYPTED[!98ace8259c6024da912c14d5a3c5c6aac186890a8d4819fad78f3e0c41a4e0cd3a2537dd6e91493952fb056fa434be7c!] GORELEASER_KEY: ENCRYPTED[!9b80b6ef684ceaf40edd4c7af93014ee156c8aba7e6e5795f41c482729887b5c31f36b651491d790f1f668670888d9fd!] diff --git a/Package.resolved b/Package.resolved index 8896f8c..182f49c 100644 --- a/Package.resolved +++ b/Package.resolved @@ -8,6 +8,33 @@ "revision" : "f3c9084a71ef4376f2fabbdf1d3d90a49f1fabdb", "version" : "1.1.2" } + }, + { + "identity" : "swift-case-paths", + "kind" : "remoteSourceControl", + "location" : "https://github.com/pointfreeco/swift-case-paths", + "state" : { + "revision" : "ce9c0d897db8a840c39de64caaa9b60119cf4be8", + "version" : "0.8.1" + } + }, + { + "identity" : "swift-parsing", + "kind" : "remoteSourceControl", + "location" : "https://github.com/pointfreeco/swift-parsing", + "state" : { + "revision" : "28d32e9ace1c4c43f5e5a177be837a202494c2d5", + "version" : "0.9.2" + } + }, + { + "identity" : "xctest-dynamic-overlay", + "kind" : "remoteSourceControl", + "location" : "https://github.com/pointfreeco/xctest-dynamic-overlay", + "state" : { + "revision" : "50a70a9d3583fe228ce672e8923010c8df2deddd", + "version" : "0.2.1" + } } ], "version" : 2 diff --git a/Package.swift b/Package.swift index 2a3a38c..33c531a 100644 --- a/Package.swift +++ b/Package.swift @@ -12,11 +12,13 @@ let package = Package( ], dependencies: [ .package(url: "https://github.com/apple/swift-argument-parser", from: "1.1.2"), + .package(url: "https://github.com/pointfreeco/swift-parsing", from: "0.9.2"), ], targets: [ - .executableTarget(name: "tart", - dependencies: [ - .product(name: "ArgumentParser", package: "swift-argument-parser"), - ]), + .executableTarget(name: "tart", dependencies: [ + .product(name: "ArgumentParser", package: "swift-argument-parser"), + .product(name: "Parsing", package: "swift-parsing"), + ]), + .testTarget(name: "TartTests", dependencies: ["tart"]) ] ) diff --git a/Sources/tart/Commands/Clone.swift b/Sources/tart/Commands/Clone.swift index 7642919..d228662 100644 --- a/Sources/tart/Commands/Clone.swift +++ b/Sources/tart/Commands/Clone.swift @@ -1,7 +1,6 @@ import ArgumentParser import Foundation import SystemConfiguration -import Virtualization struct Clone: AsyncParsableCommand { static var configuration = CommandConfiguration(abstract: "Clone a VM") @@ -14,17 +13,13 @@ struct Clone: AsyncParsableCommand { func run() async throws { do { - let vmStorage = VMStorage() - let sourceVMDir = try vmStorage.read(sourceName) - let newVMDir = try vmStorage.create(newName) + // Pull the VM in case it's OCI-based and doesn't exist locally yet + if let remoteName = try? RemoteName(sourceName), !VMStorageOCI().exists(remoteName) { + let registry = try Registry(host: remoteName.host, namespace: remoteName.namespace) + try await VMStorageOCI().pull(remoteName, registry: registry) + } - try FileManager.default.copyItem(at: sourceVMDir.configURL, to: newVMDir.configURL) - try FileManager.default.copyItem(at: sourceVMDir.nvramURL, to: newVMDir.nvramURL) - try FileManager.default.copyItem(at: sourceVMDir.diskURL, to: newVMDir.diskURL) - - var newVMConfig = try VMConfig(fromURL: newVMDir.configURL) - newVMConfig.macAddress = VZMACAddress.randomLocallyAdministered() - try newVMConfig.save(toURL: newVMDir.configURL) + try VMStorageHelper.open(sourceName).clone(to: VMStorageLocal().create(newName)) Foundation.exit(0) } catch { diff --git a/Sources/tart/Commands/Create.swift b/Sources/tart/Commands/Create.swift index a5f7d1b..6f05c4e 100644 --- a/Sources/tart/Commands/Create.swift +++ b/Sources/tart/Commands/Create.swift @@ -23,7 +23,7 @@ struct Create: AsyncParsableCommand { func run() async throws { do { - let vmDir = try VMStorage().create(name) + let vmDir = try VMStorageLocal().create(name) if fromIPSW! == "latest" { _ = try await VM(vmDir: vmDir, ipswURL: nil, diskSizeGB: diskSize) diff --git a/Sources/tart/Commands/Delete.swift b/Sources/tart/Commands/Delete.swift index 3964742..70337b6 100644 --- a/Sources/tart/Commands/Delete.swift +++ b/Sources/tart/Commands/Delete.swift @@ -10,7 +10,7 @@ struct Delete: AsyncParsableCommand { func run() async throws { do { - try VMStorage().delete(name) + try VMStorageHelper.delete(name) Foundation.exit(0) } catch { diff --git a/Sources/tart/Commands/IP.swift b/Sources/tart/Commands/IP.swift index 3f82072..75d417c 100644 --- a/Sources/tart/Commands/IP.swift +++ b/Sources/tart/Commands/IP.swift @@ -14,7 +14,7 @@ struct IP: AsyncParsableCommand { func run() async throws { do { - let vmDir = try VMStorage().read(name) + let vmDir = try VMStorageLocal().open(name) let vmConfig = try VMConfig.init(fromURL: vmDir.configURL) guard let ip = try await resolveIP(vmConfig, secondsToWait: wait) else { diff --git a/Sources/tart/Commands/List.swift b/Sources/tart/Commands/List.swift index fb35701..56c1e54 100644 --- a/Sources/tart/Commands/List.swift +++ b/Sources/tart/Commands/List.swift @@ -7,9 +7,10 @@ struct List: AsyncParsableCommand { func run() async throws { do { - for vmURL in try VMStorage().list() { - print(vmURL) - } + print("Name\tSource") + + displayTable("local", try VMStorageLocal().list()) + displayTable("oci", try VMStorageOCI().list()) Foundation.exit(0) } catch { @@ -18,4 +19,10 @@ struct List: AsyncParsableCommand { Foundation.exit(1) } } + + private func displayTable(_ source: String, _ vms: [(String, VMDirectory)]) { + for (name, _) in vms { + print("\(source)\t\(name)") + } + } } diff --git a/Sources/tart/Commands/Login.swift b/Sources/tart/Commands/Login.swift new file mode 100644 index 0000000..f45ea77 --- /dev/null +++ b/Sources/tart/Commands/Login.swift @@ -0,0 +1,24 @@ +import ArgumentParser +import Dispatch +import SwiftUI + +struct Login: AsyncParsableCommand { + static var configuration = CommandConfiguration(abstract: "Login to a registry") + + @Argument(help: "host") + var host: String + + func run() async throws { + do { + let (user, password) = try Credentials.retrieveStdin() + + try Credentials.store(host: host, user: user, password: password) + + Foundation.exit(0) + } catch { + print(error) + + Foundation.exit(1) + } + } +} diff --git a/Sources/tart/Commands/Pull.swift b/Sources/tart/Commands/Pull.swift new file mode 100644 index 0000000..da1a111 --- /dev/null +++ b/Sources/tart/Commands/Pull.swift @@ -0,0 +1,27 @@ +import ArgumentParser +import Dispatch +import SwiftUI + +struct Pull: AsyncParsableCommand { + static var configuration = CommandConfiguration(abstract: "Pull a VM from a registry") + + @Argument(help: "remote VM name") + var remoteName: String + + func run() async throws { + do { + let remoteName = try RemoteName(remoteName) + let registry = try Registry(host: remoteName.host, namespace: remoteName.namespace) + + defaultLogger.appendNewLine("pulling \(remoteName)...") + + try await VMStorageOCI().pull(remoteName, registry: registry) + + Foundation.exit(0) + } catch { + print(error) + + Foundation.exit(1) + } + } +} diff --git a/Sources/tart/Commands/Push.swift b/Sources/tart/Commands/Push.swift new file mode 100644 index 0000000..b9ccfa0 --- /dev/null +++ b/Sources/tart/Commands/Push.swift @@ -0,0 +1,53 @@ +import ArgumentParser +import Dispatch +import Foundation +import Compression + +struct Push: AsyncParsableCommand { + static var configuration = CommandConfiguration(abstract: "Push a VM to a registry") + + @Argument(help: "local VM name") + var localName: String + + @Argument(help: "remote VM name(s)") + var remoteNames: [String] + + func run() async throws { + do { + let localVMDir = try VMStorageLocal().open(localName) + + // Parse remote names supplied by the user + let remoteNames = try remoteNames.map{ + try RemoteName($0) + } + + // Group remote names by registry + struct RegistryIdentifier: Hashable, Equatable { + var host: String + var namespace: String + } + + let registryGroups = Dictionary(grouping: remoteNames, by: { + RegistryIdentifier(host: $0.host, namespace: $0.namespace) + }) + + // Push VM + for (registryIdentifier, remoteNamesForRegistry) in registryGroups { + let registry = try Registry(host: registryIdentifier.host, namespace: registryIdentifier.namespace) + + let listOfTagsAndDigests = "{" + remoteNamesForRegistry.map{$0.fullyQualifiedReference } + .joined(separator: ",") + "}" + defaultLogger.appendNewLine("pushing \(localName) to " + + "\(registryIdentifier.host)/\(registryIdentifier.namespace)\(listOfTagsAndDigests)...") + + try await localVMDir.pushToRegistry(registry: registry, references: remoteNamesForRegistry.map{ $0.reference }) + } + + Foundation.exit(0) + } catch { + print(error) + + Foundation.exit(1) + } + } +} diff --git a/Sources/tart/Commands/Run.swift b/Sources/tart/Commands/Run.swift index 511d425..ee27499 100644 --- a/Sources/tart/Commands/Run.swift +++ b/Sources/tart/Commands/Run.swift @@ -12,10 +12,10 @@ struct Run: AsyncParsableCommand { var name: String @Flag var noGraphics: Bool = false - + @MainActor func run() async throws { - let vmDir = try VMStorage().read(name) + let vmDir = try VMStorageLocal().open(name) vm = try VM(vmDir: vmDir) Task { diff --git a/Sources/tart/Commands/Set.swift b/Sources/tart/Commands/Set.swift index 283297e..175cfb4 100644 --- a/Sources/tart/Commands/Set.swift +++ b/Sources/tart/Commands/Set.swift @@ -21,8 +21,7 @@ struct Set: AsyncParsableCommand { func run() async throws { do { - let vmStorage = VMStorage() - let vmDir = try vmStorage.read(name) + let vmDir = try VMStorageLocal().open(name) var vmConfig = try VMConfig(fromURL: vmDir.configURL) if let cpu = cpu { @@ -46,7 +45,7 @@ struct Set: AsyncParsableCommand { } try vmConfig.save(toURL: vmDir.configURL) - + if diskSize != nil { try vmDir.resizeDisk(diskSize!) } diff --git a/Sources/tart/Config.swift b/Sources/tart/Config.swift new file mode 100644 index 0000000..2e73675 --- /dev/null +++ b/Sources/tart/Config.swift @@ -0,0 +1,9 @@ +import Foundation + +struct Config { + public static let tartHomeDir: URL = FileManager.default + .homeDirectoryForCurrentUser + .appendingPathComponent(".tart", isDirectory: true) + + public static let tartCacheDir: URL = tartHomeDir.appendingPathComponent("cache", isDirectory: true) +} diff --git a/Sources/tart/Credentials.swift b/Sources/tart/Credentials.swift new file mode 100644 index 0000000..fea8d3f --- /dev/null +++ b/Sources/tart/Credentials.swift @@ -0,0 +1,72 @@ +import Foundation + +class Credentials { + static func retrieve(host: String) throws -> (String, String) { + do { + return try retrieveKeychain(host: host) + } catch RegistryError.AuthFailed { + return try retrieveStdin() + } + } + + static func retrieveKeychain(host: String) throws -> (String, String) { + let query: [String: Any] = [kSecClass as String: kSecClassInternetPassword, + kSecAttrProtocol as String: kSecAttrProtocolHTTPS, + kSecAttrServer as String: host, + kSecMatchLimit as String: kSecMatchLimitOne, + kSecReturnAttributes as String: true, + kSecReturnData as String: true, + kSecAttrLabel as String: "Tart Credentials", + ] + + var item: CFTypeRef? + let status = SecItemCopyMatching(query as CFDictionary, &item) + + if status != errSecSuccess { + if status == errSecItemNotFound { + throw RegistryError.AuthFailed(why: "Keychain item not found") + } + + throw RegistryError.AuthFailed(why: "Keychain returned unsuccessful status \(status)") + } + + guard let item = item as? [String: Any], + let user = item[kSecAttrAccount as String] as? String, + let passwordData = item[kSecValueData as String] as? Data, + let password = String(data: passwordData, encoding: .utf8) + else { + throw RegistryError.AuthFailed(why: "Keychain item has unexpected format") + } + + return (user, password) + } + + static func retrieveStdin() throws -> (String, String) { + print("User: ", terminator: "") + let user = readLine() ?? "" + + let rawPass = getpass("Password: ") + let pass = String(cString: rawPass!, encoding: .utf8)! + + return (user, pass) + } + + static func store(host: String, user: String, password: String) throws { + let attributes: [String: Any] = [kSecClass as String: kSecClassInternetPassword, + kSecAttrAccount as String: user, + kSecAttrProtocol as String: kSecAttrProtocolHTTPS, + kSecAttrServer as String: host, + kSecValueData as String: password, + kSecAttrLabel as String: "Tart Credentials", + ] + + let status = SecItemAdd(attributes as CFDictionary, nil) + + switch status { + case errSecSuccess, errSecDuplicateItem: + return + default: + throw RegistryError.AuthFailed(why: "Keychain returned unsuccessful status \(status)") + } + } +} diff --git a/Sources/tart/OCI/Digest.swift b/Sources/tart/OCI/Digest.swift new file mode 100644 index 0000000..703c340 --- /dev/null +++ b/Sources/tart/OCI/Digest.swift @@ -0,0 +1,27 @@ +import Foundation +import CryptoKit + +class Digest { + var hash: SHA256 = SHA256() + + func update(_ data: Data) { + hash.update(data: data) + } + + func finalize() -> String { + hash.finalize().hexdigest() + } + + static func hash(_ data: Data) -> String { + SHA256.hash(data: data).hexdigest() + } +} + +extension SHA256.Digest { + func hexdigest() -> String { + "sha256:" + self.map { + String(format: "%02x", $0) + } + .joined() + } +} diff --git a/Sources/tart/OCI/Manifest.swift b/Sources/tart/OCI/Manifest.swift new file mode 100644 index 0000000..4d6bd61 --- /dev/null +++ b/Sources/tart/OCI/Manifest.swift @@ -0,0 +1,28 @@ +import Foundation + +let ociManifestMediaType = "application/vnd.oci.image.manifest.v1+json" +let ociConfigMediaType = "application/vnd.oci.image.config.v1+json" + +struct OCIManifest: Encodable, Decodable { + var schemaVersion: Int = 2 + var mediaType: String = ociManifestMediaType + var config: OCIManifestConfig + var layers: [OCIManifestLayer] = Array() +} + +struct OCIManifestConfig: Encodable, Decodable { + var mediaType: String = ociConfigMediaType + var size: Int + var digest: String +} + +struct OCIManifestLayer: Encodable, Decodable { + var mediaType: String + var size: Int + var digest: String +} + +struct Descriptor { + var size: Int + var digest: String +} diff --git a/Sources/tart/OCI/Registry.swift b/Sources/tart/OCI/Registry.swift new file mode 100644 index 0000000..0ba3cad --- /dev/null +++ b/Sources/tart/OCI/Registry.swift @@ -0,0 +1,227 @@ +import Foundation + +enum RegistryError: Error { + case UnexpectedHTTPStatusCode(when: String, code: Int, details: String = "") + case MissingLocationHeader + case AuthFailed(why: String) + case MalformedHeader(why: String) +} + +struct TokenResponse: Decodable { + var token: String +} + +class Registry { + var baseURL: URL + var namespace: String + var user: String + var password: String + + var token: String? = nil + + init(host: String, namespace: String) throws { + var baseURLComponents = URLComponents() + baseURLComponents.scheme = "https" + baseURLComponents.host = host + baseURLComponents.path = "/v2/" + + baseURL = baseURLComponents.url! + self.namespace = namespace + (user, password) = try Credentials.retrieve(host: host) + } + + func pushManifest(reference: String, config: Descriptor, layers: [OCIManifestLayer]) async throws -> String { + let manifest = OCIManifest(config: OCIManifestConfig(size: config.size, digest: config.digest), + layers: layers) + let manifestJSON = try JSONEncoder().encode(manifest) + + let (responseData, response) = try await endpointRequest("PUT", "\(namespace)/manifests/\(reference)", + headers: ["Content-Type": manifest.mediaType], + body: manifestJSON) + if response.statusCode != 201 { + throw RegistryError.UnexpectedHTTPStatusCode(when: "pushing manifest", code: response.statusCode, + details: String(decoding: responseData, as: UTF8.self)) + } + + return Digest.hash(manifestJSON) + } + + public func pullManifest(reference: String) async throws -> (OCIManifest, Data) { + let (responseData, response) = try await endpointRequest("GET", "\(namespace)/manifests/\(reference)", + headers: ["Accept": ociManifestMediaType]) + if response.statusCode != 200 { + throw RegistryError.UnexpectedHTTPStatusCode(when: "pulling manifest", code: response.statusCode, + details: String(decoding: responseData, as: UTF8.self)) + } + + let manifest = try JSONDecoder().decode(OCIManifest.self, from: responseData) + + return (manifest, responseData) + } + + private func uploadLocationFromResponse(response: HTTPURLResponse) throws -> URLComponents { + guard let uploadLocationRaw = response.value(forHTTPHeaderField: "Location") else { + throw RegistryError.MissingLocationHeader + } + + var uploadLocation = URL(string: uploadLocationRaw)! + + // If the URL provided in the Location header + // is relative — make it absolute. + if uploadLocation.absoluteString == uploadLocation.relativeString { + uploadLocation = URL(string: uploadLocation.path, relativeTo: baseURL)! + } + + return URLComponents(url: uploadLocation, resolvingAgainstBaseURL: true)! + } + + public func pushBlob(fromData: Data, chunkSize: Int = 5 * 1024 * 1024) async throws -> String { + // Initiate a blob upload + let (postData, postResponse) = try await endpointRequest("POST", "\(namespace)/blobs/uploads/", + headers: ["Content-Length": "0"]) + if postResponse.statusCode != 202 { + throw RegistryError.UnexpectedHTTPStatusCode(when: "pushing blob (POST)", code: postResponse.statusCode, + details: String(decoding: postData, as: UTF8.self)) + } + + // Figure out where to upload the blob + let uploadLocation = try uploadLocationFromResponse(response: postResponse) + + // Upload the blob + let headers = [ + "Content-Length": "\(fromData.count)", + "Content-Type": "application/octet-stream", + ] + + let digest = Digest.hash(fromData) + let parameters = [ + "digest": digest, + ] + + let (putData, putResponse) = try await rawRequest("PUT", uploadLocation, headers: headers, parameters: parameters, + body: fromData) + if putResponse.statusCode != 201 { + throw RegistryError.UnexpectedHTTPStatusCode(when: "pushing blob (PUT)", code: putResponse.statusCode, + details: String(decoding: putData, as: UTF8.self)) + } + + return digest + } + + public func pullBlob(_ digest: String) async throws -> Data { + let (putData, putResponse) = try await endpointRequest("GET", "\(namespace)/blobs/\(digest)") + if putResponse.statusCode != 200 { + throw RegistryError.UnexpectedHTTPStatusCode(when: "pulling blob", code: putResponse.statusCode, + details: String(decoding: putData, as: UTF8.self)) + } + + return putData + } + + private func endpointRequest( + _ method: String, + _ endpoint: String, + headers: Dictionary = Dictionary(), + parameters: Dictionary = Dictionary(), + body: Data? = nil + ) async throws -> (Data, HTTPURLResponse) { + let url = URL(string: endpoint, relativeTo: baseURL)! + let urlComponents = URLComponents(url: url, resolvingAgainstBaseURL: true)! + + return try await rawRequest(method, urlComponents, headers: headers, parameters: parameters, body: body) + } + + private func rawRequest( + _ method: String, + _ urlComponents: URLComponents, + headers: Dictionary = Dictionary(), + parameters: Dictionary = Dictionary(), + body: Data? = nil + ) async throws -> (Data, HTTPURLResponse) { + var urlComponents = urlComponents + + if !parameters.isEmpty { + urlComponents.queryItems = Array() + } + urlComponents.queryItems?.append(contentsOf: parameters.map { key, value -> URLQueryItem in + URLQueryItem(name: key, value: value) + }) + + var request = URLRequest(url: urlComponents.url!) + request.httpMethod = method + for (key, value) in headers { + request.addValue(value, forHTTPHeaderField: key) + } + request.httpBody = body + + var (data, response) = try await authAwareRequest(request: request) + + if response.statusCode == 401 { + try await auth(response: response) + (data, response) = try await authAwareRequest(request: request) + } + + return (data, response) + } + + private func auth(response: HTTPURLResponse) async throws { + // Process WWW-Authenticate header + guard let wwwAuthenticateRaw = response.value(forHTTPHeaderField: "WWW-Authenticate") else { + throw RegistryError.AuthFailed(why: "got HTTP 401, but WWW-Authenticate header is missing") + } + + let wwwAuthenticate = try WWWAuthenticate(rawHeaderValue: wwwAuthenticateRaw) + if wwwAuthenticate.scheme != "Bearer" { + throw RegistryError.AuthFailed(why: "WWW-Authenticate header's authentication scheme " + + "\"\(wwwAuthenticate.scheme)\" is unsupported, expected \"Bearer\" scheme") + } + guard let realm = wwwAuthenticate.kvs["realm"] else { + throw RegistryError.AuthFailed(why: "WWW-Authenticate header is missing a \"realm\" directive") + } + + // Request a token + guard var authenticateURL = URLComponents(string: realm) else { + throw RegistryError.AuthFailed(why: "WWW-Authenticate header's realm directive " + + "\"\(realm)\" doesn't look like URL") + } + + // Token Authentication Specification[1]: + // + // >To respond to this challenge, the client will need to make a GET request + // >[...] using the service and scope values from the WWW-Authenticate header. + // + // [1]: https://docs.docker.com/registry/spec/auth/token/ + authenticateURL.queryItems = ["scope", "service"].compactMap { key in + if let value = wwwAuthenticate.kvs[key] { + return URLQueryItem(name: key, value: value) + } else { + return nil + } + } + + let encodedCredentials = "\(user):\(password)".data(using: .utf8)?.base64EncodedString() + let headers = [ + "Authorization": "Basic \(encodedCredentials!)" + ] + + let (responseData, response) = try await rawRequest("GET", authenticateURL, headers: headers) + if response.statusCode != 200 { + throw RegistryError.AuthFailed(why: "received unexpected HTTP status code \(response.statusCode) " + + "while retrieving an authentication token") + } + + token = try JSONDecoder().decode(TokenResponse.self, from: responseData).token + } + + private func authAwareRequest(request: URLRequest) async throws -> (Data, HTTPURLResponse) { + var request = request + + if let token = self.token { + request.addValue("Bearer \(token)", forHTTPHeaderField: "Authorization") + } + + let (responseData, response) = try await URLSession.shared.data(for: request) + + return (responseData, response as! HTTPURLResponse) + } +} diff --git a/Sources/tart/OCI/RemoteName.swift b/Sources/tart/OCI/RemoteName.swift new file mode 100644 index 0000000..50bd61b --- /dev/null +++ b/Sources/tart/OCI/RemoteName.swift @@ -0,0 +1,103 @@ +import Foundation +import Parsing + +struct Tail { + enum TailType { + case Tag + case Digest + } + + var type: TailType + var value: String +} + +struct RemoteName: Comparable, CustomStringConvertible { + var host: String + var namespace: String + var reference: String = "latest" + var fullyQualifiedReference: String { + get { + if reference.starts(with: "sha256:") { + return "@" + reference + } + + return ":" + reference + } + } + + init(host: String, namespace: String, reference: String) { + self.host = host + self.namespace = namespace + self.reference = reference + } + + init(_ name: String) throws { + let csNormal = [ + UInt8(ascii: "a")...UInt8(ascii: "z"), + UInt8(ascii: "0")...UInt8(ascii: "9"), + ].asCharacterSet().union(CharacterSet(charactersIn: "_-.")) + + let csHex = [ + UInt8(ascii: "a")...UInt8(ascii: "f"), + UInt8(ascii: "0")...UInt8(ascii: "9"), + ].asCharacterSet() + + let parser = Parse { + Consumed { + csNormal + Optionally { + ":" + Digits() + } + } + "/" + csNormal.union(CharacterSet(charactersIn: "/")) + Optionally { + OneOf { + Parse { + ":" + csNormal.map { + Tail(type: .Tag, value: String($0)) + } + } + Parse { + "@sha256:" + csHex.map { + Tail(type: .Digest, value: "sha256:" + String($0)) + } + } + } + } + End() + } + + let result = try parser.parse(name) + + host = String(result.0) + namespace = String(result.1) + if let tail = result.2 { + reference = tail.value + } + } + + static func <(lhs: RemoteName, rhs: RemoteName) -> Bool { + if lhs.host != rhs.host { + return lhs.host < rhs.host + } else if lhs.namespace != rhs.namespace { + return lhs.namespace < rhs.namespace + } else { + return lhs.reference < rhs.reference + } + } + + var description: String { + "\(host)/\(namespace)\(fullyQualifiedReference)" + } +} + +extension Array where Self.Element == ClosedRange { + func asCharacterSet() -> CharacterSet { + let characters = self.joined().map { String(UnicodeScalar($0)) }.joined() + return CharacterSet(charactersIn: characters) + } +} diff --git a/Sources/tart/OCI/WWWAuthenticate.swift b/Sources/tart/OCI/WWWAuthenticate.swift new file mode 100644 index 0000000..ac5d5c2 --- /dev/null +++ b/Sources/tart/OCI/WWWAuthenticate.swift @@ -0,0 +1,63 @@ +import Foundation + +// WWW-Authenticate header parser based on details from RFCs[1][2] +/// +// [1]: https://www.rfc-editor.org/rfc/rfc2617#section-3.2.1 +// [2]: https://www.rfc-editor.org/rfc/rfc6750#section-3 +class WWWAuthenticate { + var scheme: String + var kvs: Dictionary = Dictionary() + + init(rawHeaderValue: String) throws { + let splits = rawHeaderValue.split(separator: " ", maxSplits: 1) + + if splits.count == 2 { + scheme = String(splits[0]) + } else { + throw RegistryError.MalformedHeader(why: "WWW-Authenticate header should consist of two parts: " + + "scheme and directives") + } + + let rawDirectives = contextAwareCommaSplit(rawDirectives: String(splits[1])) + + try rawDirectives.forEach { sequence in + let parts = sequence.split(separator: "=", maxSplits: 1) + if parts.count != 2 { + throw RegistryError.MalformedHeader(why: "Each WWW-Authenticate header directive should be in the form of " + + "key=value or key=\"value\"") + } + + let key = String(parts[0]) + var value = String(parts[1]) + value = value.trimmingCharacters(in: CharacterSet(charactersIn: "\"")) + + kvs[key] = value + } + } + + private func contextAwareCommaSplit(rawDirectives: String) -> Array { + var result: Array = Array() + var inQuotation: Bool = false + var accumulator: Array = Array() + + for ch in rawDirectives { + if ch == "," && !inQuotation { + result.append(String(accumulator)) + accumulator.removeAll() + continue + } + + accumulator.append(ch) + + if ch == "\"" { + inQuotation.toggle() + } + } + + if !accumulator.isEmpty { + result.append(String(accumulator)) + } + + return result + } +} diff --git a/Sources/tart/Root.swift b/Sources/tart/Root.swift index 4d930d9..2059138 100644 --- a/Sources/tart/Root.swift +++ b/Sources/tart/Root.swift @@ -4,5 +4,16 @@ import ArgumentParser struct Root: AsyncParsableCommand { static var configuration = CommandConfiguration( commandName: "tart", - subcommands: [Create.self, Clone.self, Run.self, Set.self, List.self, IP.self, Delete.self]) + subcommands: [ + Create.self, + Clone.self, + Run.self, + Set.self, + List.self, + Login.self, + IP.self, + Pull.self, + Push.self, + Delete.self, + ]) } diff --git a/Sources/tart/VM.swift b/Sources/tart/VM.swift index 241b39c..e1423ef 100644 --- a/Sources/tart/VM.swift +++ b/Sources/tart/VM.swift @@ -46,7 +46,7 @@ class VM: NSObject, VZVirtualMachineDelegate, ObservableObject { } - let ipswCacheFolder = VMStorage.tartCacheDir.appendingPathComponent("IPSWs", isDirectory: true) + let ipswCacheFolder = Config.tartCacheDir.appendingPathComponent("IPSWs", isDirectory: true) try FileManager.default.createDirectory(at: ipswCacheFolder, withIntermediateDirectories: true) let expectedIPSWLocation = ipswCacheFolder.appendingPathComponent("\(image.buildVersion).ipsw", isDirectory: false) diff --git a/Sources/tart/VMConfig.swift b/Sources/tart/VMConfig.swift index fbba131..c5c7d73 100644 --- a/Sources/tart/VMConfig.swift +++ b/Sources/tart/VMConfig.swift @@ -60,9 +60,12 @@ struct VMConfig: Codable { memorySize = memorySizeMin } + init(fromData: Data) throws { + self = try JSONDecoder().decode(VMConfig.self, from: fromData) + } + init(fromURL: URL) throws { - let jsonConfigData = try FileHandle.init(forReadingFrom: fromURL).readToEnd()! - self = try JSONDecoder().decode(VMConfig.self, from: jsonConfigData) + self = try Self(fromData: try Data(contentsOf: fromURL)) } func save(toURL: URL) throws { diff --git a/Sources/tart/VMDirectory+OCI.swift b/Sources/tart/VMDirectory+OCI.swift new file mode 100644 index 0000000..75cb21c --- /dev/null +++ b/Sources/tart/VMDirectory+OCI.swift @@ -0,0 +1,145 @@ +import Foundation +import Compression + +enum OCIError: Error { + case ShouldBeExactlyOneLayer + case ShouldBeAtLeastOneLayer + case FailedToCreateDiskFile +} + +extension VMDirectory { + private static let bufferSizeBytes = 64 * 1024 * 1024 + private static let layerLimitBytes = 500 * 1000 * 1000 + + private static let configMediaType = "application/vnd.cirruslabs.tart.config.v1" + private static let diskMediaType = "application/vnd.cirruslabs.tart.disk.v1" + private static let nvramMediaType = "application/vnd.cirruslabs.tart.nvram.v1" + + func pullFromRegistry(registry: Registry, reference: String) async throws { + defaultLogger.appendNewLine("pulling manifest") + + let (manifest, _) = try await registry.pullManifest(reference: reference) + + return try await pullFromRegistry(registry: registry, manifest: manifest) + } + + func pullFromRegistry(registry: Registry, manifest: OCIManifest) async throws { + // Pull VM's config file layer and re-serialize it into a config file + let configLayers = manifest.layers.filter { + $0.mediaType == Self.configMediaType + } + if configLayers.count != 1 { + throw OCIError.ShouldBeExactlyOneLayer + } + let configData = try await registry.pullBlob(configLayers.first!.digest) + try VMConfig(fromData: configData).save(toURL: configURL) + + // Pull VM's disk layers and decompress them sequentially into a disk file + let diskLayers = manifest.layers.filter { + $0.mediaType == Self.diskMediaType + } + if diskLayers.isEmpty { + throw OCIError.ShouldBeAtLeastOneLayer + } + if !FileManager.default.createFile(atPath: diskURL.path, contents: nil) { + throw OCIError.FailedToCreateDiskFile + } + let disk = try FileHandle(forWritingTo: diskURL) + let filter = try OutputFilter(.decompress, using: .lz4, bufferCapacity: Self.bufferSizeBytes) { data in + if let data = data { + disk.write(data) + } + } + + // Progress + let progress = Progress(totalUnitCount: Int64(diskLayers.map{ $0.size }.reduce(0) { $0 + $1 })) + defaultLogger.appendNewLine("pulling disk, \(progress.percentage())") + + for diskLayer in diskLayers { + let diskData = try await registry.pullBlob(diskLayer.digest) + try filter.write(diskData) + + // Progress + progress.completedUnitCount += Int64(diskLayer.size) + defaultLogger.updateLastLine("pulling disk, \(progress.percentage())") + } + try filter.finalize() + try disk.close() + + // Pull VM's NVRAM file layer and store it in an NVRAM file + defaultLogger.appendNewLine("pulling NVRAM") + + let nvramLayers = manifest.layers.filter { + $0.mediaType == Self.nvramMediaType + } + if nvramLayers.count != 1 { + throw OCIError.ShouldBeExactlyOneLayer + } + let nvramData = try await registry.pullBlob(nvramLayers.first!.digest) + try nvramData.write(to: nvramURL) + } + + func pushToRegistry(registry: Registry, references: [String]) async throws { + var layers = Array() + + // Read VM's config and push it as blob + let config = try VMConfig(fromURL: configURL) + let configJSON = try JSONEncoder().encode(config) + let configDigest = try await registry.pushBlob(fromData: configJSON) + layers.append(OCIManifestLayer(mediaType: Self.configMediaType, size: configJSON.count, digest: configDigest)) + + // Progress + let diskSize = try FileManager.default.attributesOfItem(atPath: diskURL.path)[.size] as! Int64 + let progress = Progress(totalUnitCount: diskSize) + defaultLogger.appendNewLine("pushing disk, \(progress.percentage())") + + // Read VM's compressed disk as chunks + // and sequentially upload them as blobs + let disk = try FileHandle(forReadingFrom: diskURL) + let compressingFilter = try InputFilter(.compress, using: .lz4, bufferCapacity: Self.bufferSizeBytes) { _ in + let data = try disk.read(upToCount: Self.bufferSizeBytes) + + // Progress + progress.completedUnitCount += Int64(data?.count ?? 0) + + return data + } + while let chunk = try compressingFilter.readData(ofLength: Self.layerLimitBytes) { + let chunkDigest = try await registry.pushBlob(fromData: chunk) + layers.append(OCIManifestLayer(mediaType: Self.diskMediaType, size: chunk.count, digest: chunkDigest)) + + // Progress + defaultLogger.updateLastLine("pushing disk, \(progress.percentage())") + } + + // Read VM's NVRAM and push it as blob + defaultLogger.appendNewLine("pushing NVRAM") + + let nvram = try FileHandle(forReadingFrom: nvramURL).readToEnd()! + let nvramDigest = try await registry.pushBlob(fromData: nvram) + layers.append(OCIManifestLayer(mediaType: Self.nvramMediaType, size: nvram.count, digest: nvramDigest)) + + // Craft a stub OCI config for Docker Hub compatibility + struct OCIConfig: Encodable, Decodable { + var architecture: String = "arm64" + var os: String = "darwin" + } + + let ociConfigJSON = try JSONEncoder().encode(OCIConfig()) + let ociConfigDigest = try await registry.pushBlob(fromData: ociConfigJSON) + let ociConfigDescriptor = Descriptor(size: ociConfigJSON.count, digest: ociConfigDigest) + + // Manifest + for reference in references { + defaultLogger.appendNewLine("pushing manifest") + + _ = try await registry.pushManifest(reference: reference, config: ociConfigDescriptor, layers: layers) + } + } +} + +extension Progress { + func percentage() -> String { + String(Int(100 * fractionCompleted)) + "%" + } +} diff --git a/Sources/tart/VMDirectory.swift b/Sources/tart/VMDirectory.swift index 6510790..a675e0b 100644 --- a/Sources/tart/VMDirectory.swift +++ b/Sources/tart/VMDirectory.swift @@ -1,4 +1,5 @@ import Foundation +import Virtualization struct UninitializedVMDirectoryError: Error { } @@ -7,7 +8,6 @@ struct AlreadyInitializedVMDirectoryError: Error { } struct VMDirectory { - var name: String var baseURL: URL var configURL: URL { @@ -20,18 +20,26 @@ struct VMDirectory { baseURL.appendingPathComponent("nvram.bin") } + var name: String { + baseURL.lastPathComponent + } + var initialized: Bool { FileManager.default.fileExists(atPath: configURL.path) && FileManager.default.fileExists(atPath: diskURL.path) && FileManager.default.fileExists(atPath: nvramURL.path) } - func initialize() throws { - if initialized { + func initialize(overwrite: Bool = false) throws { + if !overwrite && initialized { throw AlreadyInitializedVMDirectoryError() } try FileManager.default.createDirectory(at: baseURL, withIntermediateDirectories: true, attributes: nil) + + try? FileManager.default.removeItem(at: configURL) + try? FileManager.default.removeItem(at: diskURL) + try? FileManager.default.removeItem(at: nvramURL) } func validate() throws { @@ -39,7 +47,18 @@ struct VMDirectory { throw UninitializedVMDirectoryError() } } - + + func clone(to: VMDirectory) throws { + try FileManager.default.copyItem(at: configURL, to: to.configURL) + try FileManager.default.copyItem(at: nvramURL, to: to.nvramURL) + try FileManager.default.copyItem(at: diskURL, to: to.diskURL) + + // Re-generate MAC address + var newVMConfig = try VMConfig(fromURL: to.configURL) + newVMConfig.macAddress = VZMACAddress.randomLocallyAdministered() + try newVMConfig.save(toURL: to.configURL) + } + func resizeDisk(_ sizeGB: UInt8) throws { if !FileManager.default.fileExists(atPath: diskURL.path) { FileManager.default.createFile(atPath: diskURL.path, contents: nil, attributes: nil) diff --git a/Sources/tart/VMStorage.swift b/Sources/tart/VMStorage.swift deleted file mode 100644 index b8b0b02..0000000 --- a/Sources/tart/VMStorage.swift +++ /dev/null @@ -1,58 +0,0 @@ -import Foundation - -struct VMStorage { - public static let tartHomeDir: URL = FileManager.default - .homeDirectoryForCurrentUser - .appendingPathComponent(".tart", isDirectory: true) - - public static let tartVMsDir: URL = tartHomeDir.appendingPathComponent("vms", isDirectory: true) - public static let tartCacheDir: URL = tartHomeDir.appendingPathComponent("cache", isDirectory: true) - - func create(_ name: String) throws -> VMDirectory { - let vmDir = VMDirectory(name: name, baseURL: vmURL(name)) - - try vmDir.initialize() - - return vmDir - } - - func read(_ name: String) throws -> VMDirectory { - let vmDir = VMDirectory(name: name, baseURL: vmURL(name)) - - try vmDir.validate() - - return vmDir - } - - func delete(_ name: String) throws { - try FileManager.default.removeItem(at: vmURL(name)) - } - - func list() throws -> [URL] { - do { - return try FileManager.default.contentsOfDirectory( - at: VMStorage.tartVMsDir, - includingPropertiesForKeys: [.isDirectoryKey], - options: .skipsSubdirectoryDescendants) - } catch { - if error.isFileNotFound() { - return [] - } - - throw error - } - } - - private func vmURL(_ name: String) -> URL { - return URL.init( - fileURLWithPath: name, - isDirectory: true, - relativeTo: VMStorage.tartVMsDir) - } -} - -extension Error { - func isFileNotFound() -> Bool { - return (self as NSError).code == NSFileReadNoSuchFileError - } -} diff --git a/Sources/tart/VMStorageHelper.swift b/Sources/tart/VMStorageHelper.swift new file mode 100644 index 0000000..664a47a --- /dev/null +++ b/Sources/tart/VMStorageHelper.swift @@ -0,0 +1,53 @@ +import Foundation + +class VMStorageHelper { + static func open(_ name: String) throws -> VMDirectory { + try missingVMWrap(name) { + if let remoteName = try? RemoteName(name) { + return try VMStorageOCI().open(remoteName) + } else { + return try VMStorageLocal().open(name) + } + } + } + + static func delete(_ name: String) throws { + try missingVMWrap(name) { + if let remoteName = try? RemoteName(name) { + try VMStorageOCI().delete(remoteName) + } else { + try VMStorageLocal().delete(name) + } + } + } + + private static func missingVMWrap(_ name: String, closure: () throws -> R) throws -> R { + do { + return try closure() + } catch { + if error.isFileNotFound() { + throw RuntimeError("source VM \"\(name)\" not found, is it listed in \"tart list\"?") + } + + throw error + } + } +} + +extension Error { + func isFileNotFound() -> Bool { + (self as NSError).code == NSFileReadNoSuchFileError + } +} + +class RuntimeError: Error, CustomStringConvertible { + let message: String + + init(_ message: String) { + self.message = message + } + + var description: String { + message + } +} diff --git a/Sources/tart/VMStorageLocal.swift b/Sources/tart/VMStorageLocal.swift new file mode 100644 index 0000000..3c35bf9 --- /dev/null +++ b/Sources/tart/VMStorageLocal.swift @@ -0,0 +1,52 @@ +import Foundation + +class VMStorageLocal { + let baseURL: URL = Config.tartHomeDir.appendingPathComponent("vms", isDirectory: true) + + private func vmURL(_ name: String) -> URL { + baseURL.appendingPathComponent(name, isDirectory: true) + } + + func exists(_ name: String) -> Bool { + VMDirectory(baseURL: vmURL(name)).initialized + } + + func open(_ name: String) throws -> VMDirectory { + let vmDir = VMDirectory(baseURL: vmURL(name)) + + try vmDir.validate() + + return vmDir + } + + func create(_ name: String, overwrite: Bool = false) throws -> VMDirectory { + let vmDir = VMDirectory(baseURL: vmURL(name)) + + try vmDir.initialize(overwrite: overwrite) + + return vmDir + } + + func delete(_ name: String) throws { + try FileManager.default.removeItem(at: vmURL(name)) + } + + func list() throws -> [(String, VMDirectory)] { + do { + return try FileManager.default.contentsOfDirectory( + at: baseURL, + includingPropertiesForKeys: [.isDirectoryKey], + options: .skipsSubdirectoryDescendants).map { url in + let vmDir = VMDirectory(baseURL: url) + + return (vmDir.name, vmDir) + } + } catch { + if error.isFileNotFound() { + return [] + } + + throw error + } + } +} diff --git a/Sources/tart/VMStorageOCI.swift b/Sources/tart/VMStorageOCI.swift new file mode 100644 index 0000000..1e7f40c --- /dev/null +++ b/Sources/tart/VMStorageOCI.swift @@ -0,0 +1,99 @@ +import Foundation + +class VMStorageOCI { + let baseURL = Config.tartCacheDir.appendingPathComponent("OCIs", isDirectory: true) + + private func vmURL(_ name: RemoteName) -> URL { + baseURL.appendingRemoteName(name) + } + + func exists(_ name: RemoteName) -> Bool { + VMDirectory(baseURL: vmURL(name)).initialized + } + + func open(_ name: RemoteName) throws -> VMDirectory { + let vmDir = VMDirectory(baseURL: vmURL(name)) + + try vmDir.validate() + + return vmDir + } + + func create(_ name: RemoteName, overwrite: Bool = false) throws -> VMDirectory { + let vmDir = VMDirectory(baseURL: vmURL(name)) + + try vmDir.initialize(overwrite: overwrite) + + return vmDir + } + + func delete(_ name: RemoteName) throws { + try FileManager.default.removeItem(at: vmURL(name)) + } + + func list() throws -> [(String, VMDirectory)] { + var result: [(String, VMDirectory)] = Array() + + guard let enumerator = FileManager.default.enumerator(at: baseURL, + includingPropertiesForKeys: [.isSymbolicLinkKey], options: [.producesRelativePathURLs]) else { + return [] + } + + for case let foundURL as URL in enumerator { + let vmDir = VMDirectory(baseURL: foundURL) + + if !vmDir.initialized { + continue + } + + let parts = [foundURL.deletingLastPathComponent().relativePath, foundURL.lastPathComponent] + var name: String + + if try foundURL.resourceValues(forKeys: [.isSymbolicLinkKey]).isSymbolicLink! { + name = parts.joined(separator: ":") + } else { + name = parts.joined(separator: "@") + } + + result.append((name, vmDir)) + } + + return result + } + + func pull(_ name: RemoteName, registry: Registry) async throws { + defaultLogger.appendNewLine("pulling manifest") + + let (manifest, manifestData) = try await registry.pullManifest(reference: name.reference) + + // Create directory for manifest's digest + var digestName = name + digestName.reference = Digest.hash(manifestData) + if !exists(digestName) { + let vmDir = try create(digestName) + try await vmDir.pullFromRegistry(registry: registry, manifest: manifest) + } + + // Create directory for reference if it's different + if digestName != name { + // Overwrite the old symbolic link + if FileManager.default.fileExists(atPath: vmURL(name).path) { + try FileManager.default.removeItem(at: vmURL(name)) + } + + try FileManager.default.createSymbolicLink(at: vmURL(name), withDestinationURL: vmURL(digestName)) + } + } +} + +extension URL { + func appendingRemoteName(_ name: RemoteName) -> URL { + var result: URL = self + + for pathComponent in (name.host + "/" + name.namespace + "/" + name.reference).split(separator: "/") { + result = result.appendingPathComponent(String(pathComponent)) + } + + return result + } +} diff --git a/Tests/TartTests/DigestTests.swift b/Tests/TartTests/DigestTests.swift new file mode 100644 index 0000000..1c6fa56 --- /dev/null +++ b/Tests/TartTests/DigestTests.swift @@ -0,0 +1,24 @@ +import XCTest +@testable import tart + +final class DigestTests: XCTestCase { + func testEmptyData() throws { + let data = Data("".utf8) + + let digest = Digest() + digest.update(data) + XCTAssertEqual(digest.finalize(), "sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855") + + XCTAssertEqual(Digest.hash(data), "sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855") + } + + func testNonEmptyData() throws { + let data = Data("The quick brown fox jumps over the lazy dog".utf8) + + let digest = Digest() + digest.update(data) + XCTAssertEqual(digest.finalize(), "sha256:d7a8fbb307d7809469ca9abcb0082e4f8d5651e46d3cdb762d02d0bf37c9e592") + + XCTAssertEqual(Digest.hash(data), "sha256:d7a8fbb307d7809469ca9abcb0082e4f8d5651e46d3cdb762d02d0bf37c9e592") + } +} diff --git a/Tests/TartTests/RemoteNameTests.swift b/Tests/TartTests/RemoteNameTests.swift new file mode 100644 index 0000000..bd34d0e --- /dev/null +++ b/Tests/TartTests/RemoteNameTests.swift @@ -0,0 +1,41 @@ +import XCTest +@testable import tart + +final class RemoteNameTests: XCTestCase { + func testTag() throws { + let expectedRemoteName = RemoteName(host: "ghcr.io", namespace: "a/b", reference: "latest") + + XCTAssertEqual(expectedRemoteName, try RemoteName("ghcr.io/a/b:latest")) + } + + func testDigest() throws { + let expectedRemoteName = RemoteName( + host: "ghcr.io", + namespace: "a/b", + reference: "sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855" + ) + + XCTAssertEqual(expectedRemoteName, + try RemoteName("ghcr.io/a/b@sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855")) + } + + func testASCIIOnly() throws { + // Only ASCII letters are supported + XCTAssertEqual(try? RemoteName("touché.fr/a/b:latest"), nil) + XCTAssertEqual(try? RemoteName("ghcr.io/tou/ché:latest"), nil) + XCTAssertEqual(try? RemoteName("ghcr.io/a/b:touché"), nil) + } + + func testLocal() throws { + // Local image names (those that don't include a registry) are not supported + XCTAssertEqual(try? RemoteName("debian:latest"), nil) + } + + func testPort() throws { + // Port is included in host + XCTAssertEqual(try RemoteName("127.0.0.1:8080/a/b").host, "127.0.0.1:8080") + + // Port must be specified when ":" is used + XCTAssertEqual(try? RemoteName("127.0.0.1:/a/b").host, nil) + } +} diff --git a/Tests/TartTests/WWWAuthenticateTests.swift b/Tests/TartTests/WWWAuthenticateTests.swift new file mode 100644 index 0000000..0857cda --- /dev/null +++ b/Tests/TartTests/WWWAuthenticateTests.swift @@ -0,0 +1,41 @@ +import XCTest +@testable import tart + +final class WWWAuthenticateTests: XCTestCase { + func testExample() throws { + // Test example from Token Authentication Specification[1] + // + // [1]: https://docs.docker.com/registry/spec/auth/token/ + let wwwAuthenticate = try WWWAuthenticate(rawHeaderValue: "Bearer realm=\"https://auth.docker.io/token\",service=\"registry.docker.io\",scope=\"repository:samalba/my-app:pull,push\"") + + XCTAssertEqual("Bearer", wwwAuthenticate.scheme) + XCTAssertEqual([ + "realm": "https://auth.docker.io/token", + "service": "registry.docker.io", + "scope": "repository:samalba/my-app:pull,push", + ], wwwAuthenticate.kvs) + } + + func testBasic() throws { + let wwwAuthenticate = try WWWAuthenticate(rawHeaderValue: "Bearer a=b,c=\"d\"") + + XCTAssertEqual("Bearer", wwwAuthenticate.scheme) + XCTAssertEqual(["a": "b", "c": "d"], wwwAuthenticate.kvs) + } + + func testIncompleteHeader() throws { + XCTAssertThrowsError(try WWWAuthenticate(rawHeaderValue: "Whatever")) { + XCTAssertTrue($0 is RegistryError) + } + + XCTAssertThrowsError(try WWWAuthenticate(rawHeaderValue: "Bearer ")) { + XCTAssertTrue($0 is RegistryError) + } + } + + func testIncompleteDirective() throws { + XCTAssertThrowsError(try WWWAuthenticate(rawHeaderValue: "Bearer whatever")) { + XCTAssertTrue($0 is RegistryError) + } + } +}