Skip to content

Commit

Permalink
Merge pull request #3875 from aws-amplify/5d/appsync-auth-header
Browse files Browse the repository at this point in the history
fix(api): append auth info as head fields for appSync realtime handshake request
  • Loading branch information
5d authored Sep 19, 2024
2 parents 36e5e92 + 7740b51 commit 949d160
Show file tree
Hide file tree
Showing 12 changed files with 104 additions and 168 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@
import Foundation

public enum AppSyncRealTimeRequestAuth {
private static let jsonEncoder = JSONEncoder()
private static let jsonDecoder = JSONDecoder()

case authToken(AuthToken)
case apiKey(ApiKey)
case iam(IAM)
Expand All @@ -31,33 +34,10 @@ public enum AppSyncRealTimeRequestAuth {
let amzDate: String
}

public struct URLQuery {
let header: AppSyncRealTimeRequestAuth
let payload: String

init(header: AppSyncRealTimeRequestAuth, payload: String = "{}") {
self.header = header
self.payload = payload
}

func withBaseURL(_ url: URL, encoder: JSONEncoder? = nil) -> URL {
let jsonEncoder: JSONEncoder = encoder ?? JSONEncoder()
guard let headerJsonData = try? jsonEncoder.encode(header) else {
return url
}

guard var urlComponents = URLComponents(url: url, resolvingAgainstBaseURL: false)
else {
return url
}

urlComponents.queryItems = [
URLQueryItem(name: "header", value: headerJsonData.base64EncodedString()),
URLQueryItem(name: "payload", value: try? payload.base64EncodedString())
]

return urlComponents.url ?? url
}
var authHeaders: [String: String] {
(try? Self.jsonEncoder.encode(self)).flatMap {
try? Self.jsonDecoder.decode([String: String].self, from: $0)
} ?? [:]
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,12 @@ class APIKeyAuthInterceptor {
}

extension APIKeyAuthInterceptor: WebSocketInterceptor {
func interceptConnection(url: URL) async -> URL {

func interceptConnection(request: URLRequest) async -> URLRequest {
guard let url = request.url else { return request }

let authHeader = getAuthHeader(apiKey, AppSyncRealTimeClientFactory.appSyncApiEndpoint(url).host!)
return AppSyncRealTimeRequestAuth.URLQuery(
header: .apiKey(authHeader)
).withBaseURL(url)
return request.injectAppSyncAuthToRequestHeader(auth: .apiKey(authHeader))
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,15 +57,16 @@ extension AuthTokenInterceptor: AppSyncRequestInterceptor {
}

extension AuthTokenInterceptor: WebSocketInterceptor {
func interceptConnection(url: URL) async -> URL {
func interceptConnection(request: URLRequest) async -> URLRequest {
guard let url = request.url else { return request }
let authToken = await getAuthToken()

return AppSyncRealTimeRequestAuth.URLQuery(
header: .authToken(.init(
return request.injectAppSyncAuthToRequestHeader(
auth: .authToken(.init(
host: AppSyncRealTimeClientFactory.appSyncApiEndpoint(url).host!,
authToken: authToken
))
).withBaseURL(url)
)
))
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -88,15 +88,14 @@ class IAMAuthInterceptor {
}

extension IAMAuthInterceptor: WebSocketInterceptor {
func interceptConnection(url: URL) async -> URL {

func interceptConnection(request: URLRequest) async -> URLRequest {
guard let url = request.url else { return request }
let connectUrl = AppSyncRealTimeClientFactory.appSyncApiEndpoint(url).appendingPathComponent("connect")
guard let authHeader = await getAuthHeader(connectUrl, with: "{}") else {
return connectUrl
return request
}

return AppSyncRealTimeRequestAuth.URLQuery(
header: .iam(authHeader)
).withBaseURL(url)
return request.injectAppSyncAuthToRequestHeader(auth: .iam(authHeader))
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -127,9 +127,9 @@ actor AppSyncRealTimeClientFactory: AppSyncRealTimeClientFactoryProtocol {
extension AppSyncRealTimeClientFactory {

/**
Converting appsync api url to realtime api url
1. api.example.com/graphql -> api.example.com/graphql/realtime
2. abc.appsync-api.us-east-1.amazonaws.com/graphql -> abc.appsync-realtime-api.us-east-1.amazonaws.com/graphql
Converting appsync api url to realtime api url, realtime endpoint has scheme 'wss'
1. api.example.com/graphql -> wss://api.example.com/graphql/realtime
2. abc.appsync-api.us-east-1.amazonaws.com/graphql -> wss://abc.appsync-realtime-api.us-east-1.amazonaws.com/graphql
*/
static func appSyncRealTimeEndpoint(_ url: URL) -> URL {
guard let host = url.host else {
Expand All @@ -145,6 +145,7 @@ extension AppSyncRealTimeClientFactory {
}

urlComponents.host = host.replacingOccurrences(of: "appsync-api", with: "appsync-realtime-api")
urlComponents.scheme = "wss"
guard let realTimeUrl = urlComponents.url else {
return url
}
Expand All @@ -153,9 +154,9 @@ extension AppSyncRealTimeClientFactory {
}

/**
Converting appsync realtime api url to api url
Converting appsync realtime api url to api url, api endpoint has scheme 'https'
1. api.example.com/graphql/realtime -> api.example.com/graphql
2. abc.appsync-realtime-api.us-east-1.amazonaws.com/graphql -> abc.appsync-api.us-east-1.amazonaws.com/graphql
2. abc.appsync-realtime-api.us-east-1.amazonaws.com/graphql -> https://abc.appsync-api.us-east-1.amazonaws.com/graphql
*/
static func appSyncApiEndpoint(_ url: URL) -> URL {
guard let host = url.host else {
Expand All @@ -174,6 +175,7 @@ extension AppSyncRealTimeClientFactory {
}

urlComponents.host = host.replacingOccurrences(of: "appsync-realtime-api", with: "appsync-api")
urlComponents.scheme = "https"
guard let apiUrl = urlComponents.url else {
return url
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
//
// Copyright Amazon.com Inc. or its affiliates.
// All Rights Reserved.
//
// SPDX-License-Identifier: Apache-2.0
//


import Foundation

extension URLRequest {
func injectAppSyncAuthToRequestHeader(auth: AppSyncRealTimeRequestAuth) -> URLRequest {
var requstCopy = self
auth.authHeaders.forEach { requstCopy.setValue($0.value, forHTTPHeaderField: $0.key) }
return requstCopy
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -147,61 +147,6 @@ class AppSyncRealTimeRequestAuthTests: XCTestCase {
""".shrink())
}

func testAppSyncRealTimeRequestAuth_URLQueryWithCognitoAuthHeader() {
let expectedURL = """
https://example.com?\
header=eyJBdXRob3JpemF0aW9uIjoiNDk4NTljN2MtNzQwNS00ZDU4LWFmZjctNTJiZ\
TRiNDczNTU3IiwiaG9zdCI6ImV4YW1wbGUuY29tIn0%3D\
&payload=e30%3D
"""
let encodedURL = AppSyncRealTimeRequestAuth.URLQuery(
header: .authToken(.init(
host: "example.com",
authToken: "49859c7c-7405-4d58-aff7-52be4b473557"
))
).withBaseURL(URL(string: "https://example.com")!, encoder: jsonEncoder)
XCTAssertEqual(encodedURL.absoluteString, expectedURL)
}

func testAppSyncRealTimeRequestAuth_URLQueryWithApiKeyAuthHeader() {
let expectedURL = """
https://example.com?\
header=eyJob3N0IjoiZXhhbXBsZS5jb20iLCJ4LWFtei1kYXRlIjoiOWUwZTJkZjktMmVlNy00NjU5L\
TgzNjItMWM4ODFlMTE4YzlmIiwieC1hcGkta2V5IjoiNjVlMmZhY2EtOGUxZS00ZDM3LThkYzctNjQ0N\
2Q5Njk4MjQ3In0%3D\
&payload=e30%3D
"""
let encodedURL = AppSyncRealTimeRequestAuth.URLQuery(
header: .apiKey(.init(
host: "example.com",
apiKey: "65e2faca-8e1e-4d37-8dc7-6447d9698247",
amzDate: "9e0e2df9-2ee7-4659-8362-1c881e118c9f"
))
).withBaseURL(URL(string: "https://example.com")!, encoder: jsonEncoder)
XCTAssertEqual(encodedURL.absoluteString, expectedURL)
}

func testAppSyncRealTimeRequestAuth_URLQueryWithIAMAuthHeader() {

let expectedURL = """
https://example.com?\
header=eyJhY2NlcHQiOiJhcHBsaWNhdGlvblwvanNvbiwgdGV4dFwvamF2YXNjcmlwdCIsIkF1dGhvcml6YXR\
pb24iOiJjOWRhZDg5Ny05MGQxLTRhNGMtYTVjOS0yYjM2YTI0NzczNWYiLCJjb250ZW50LWVuY29kaW5nIjoiY\
W16LTEuMCIsImNvbnRlbnQtdHlwZSI6ImFwcGxpY2F0aW9uXC9qc29uOyBjaGFyc2V0PVVURi04IiwiaG9zdCI\
6ImV4YW1wbGUuY29tIiwieC1hbXotZGF0ZSI6IjllMGUyZGY5LTJlZTctNDY1OS04MzYyLTFjODgxZTExOGM5Z\
iIsIlgtQW16LVNlY3VyaXR5LVRva2VuIjoiZTdlNjI2OWUtZmRhMS00ZGUwLThiZGItYmFhN2I2ZGQwYTBkIn0%3D\
&payload=e30%3D
"""
let encodedURL = AppSyncRealTimeRequestAuth.URLQuery(
header: .iam(.init(
host: "example.com",
authToken: "c9dad897-90d1-4a4c-a5c9-2b36a247735f",
securityToken: "e7e6269e-fda1-4de0-8bdb-baa7b6dd0a0d",
amzDate: "9e0e2df9-2ee7-4659-8362-1c881e118c9f"))
).withBaseURL(URL(string: "https://example.com")!, encoder: jsonEncoder)
XCTAssertEqual(encodedURL.absoluteString, expectedURL)
}

private func toJson(_ value: Encodable) -> String? {
return try? String(data: jsonEncoder.encode(value), encoding: .utf8)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,20 +12,13 @@ import Amplify

class APIKeyAuthInterceptorTests: XCTestCase {

func testInterceptConnection_addApiKeySignatureInURLQuery() async {
func testInterceptConnection_addApiKeyInRequestHeader() async {
let apiKey = UUID().uuidString
let interceptor = APIKeyAuthInterceptor(apiKey: apiKey)
let resultUrl = await interceptor.interceptConnection(url: URL(string: "https://example.com")!)
guard let components = URLComponents(url: resultUrl, resolvingAgainstBaseURL: false) else {
XCTFail("Failed to decode decorated URL")
return
}

let header = components.queryItems?.first { $0.name == "header" }
XCTAssertNotNil(header?.value)
let headerData = try! header?.value!.base64DecodedString().data(using: .utf8)
let decodedHeader = try! JSONDecoder().decode(JSONValue.self, from: headerData!)
XCTAssertEqual(decodedHeader["x-api-key"]?.stringValue, apiKey)
let resultUrlRequest = await interceptor.interceptConnection(request: URLRequest(url: URL(string: "https://example.com")!))

let header = resultUrlRequest.value(forHTTPHeaderField: "x-api-key")
XCTAssertEqual(header, apiKey)
}

func testInterceptRequest_appendAuthInfoInPayload() async {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,56 +13,24 @@ import Amplify

class CognitoAuthInterceptorTests: XCTestCase {

func testInterceptConnection_withAuthTokenProvider_appendCorrectAuthHeaderToQuery() async {
func testInterceptConnection_withAuthTokenProvider_appendCorrectAuthHeader() async {
let authTokenProvider = MockAuthTokenProvider()
let interceptor = AuthTokenInterceptor(authTokenProvider: authTokenProvider)

let decoratedURL = await interceptor.interceptConnection(url: URL(string: "https://example.com")!)
guard let components = URLComponents(url: decoratedURL, resolvingAgainstBaseURL: false) else {
XCTFail("Failed to get url components from decorated URL")
return
}
let decoratedURLRequest = await interceptor.interceptConnection(request: URLRequest(url:URL(string: "https://example.com")!))

guard let queryHeaderString =
try? components.queryItems?.first(where: { $0.name == "header" })?.value?.base64DecodedString()
else {
XCTFail("Failed to extract header field from query string")
return
}

guard let queryHeader = try? JSONDecoder().decode(JSONValue.self, from: queryHeaderString.data(using: .utf8)!)
else {
XCTFail("Failed to decode query header to json object")
return
}
XCTAssertEqual(authTokenProvider.authToken, queryHeader.Authorization?.stringValue)
XCTAssertEqual("example.com", queryHeader.host?.stringValue)
XCTAssertEqual(authTokenProvider.authToken, decoratedURLRequest.value(forHTTPHeaderField: "Authorization"))
XCTAssertEqual("example.com", decoratedURLRequest.value(forHTTPHeaderField: "host"))
}

func testInterceptConnection_withAuthTokenProviderFailed_appendEmptyAuthHeaderToQuery() async {
func testInterceptConnection_withAuthTokenProviderFailed_appendEmptyAuthHeader() async {
let authTokenProvider = MockAuthTokenProviderFailed()
let interceptor = AuthTokenInterceptor(authTokenProvider: authTokenProvider)

let decoratedURL = await interceptor.interceptConnection(url: URL(string: "https://example.com")!)
guard let components = URLComponents(url: decoratedURL, resolvingAgainstBaseURL: false) else {
XCTFail("Failed to get url components from decorated URL")
return
}
let decoratedURLRequest = await interceptor.interceptConnection(request: URLRequest(url:URL(string: "https://example.com")!))

guard let queryHeaderString =
try? components.queryItems?.first(where: { $0.name == "header" })?.value?.base64DecodedString()
else {
XCTFail("Failed to extract header field from query string")
return
}

guard let queryHeader = try? JSONDecoder().decode(JSONValue.self, from: queryHeaderString.data(using: .utf8)!)
else {
XCTFail("Failed to decode query header to json object")
return
}
XCTAssertEqual("", queryHeader.Authorization?.stringValue)
XCTAssertEqual("example.com", queryHeader.host?.stringValue)
XCTAssertEqual("", decoratedURLRequest.value(forHTTPHeaderField: "Authorization"))
XCTAssertEqual("example.com", decoratedURLRequest.value(forHTTPHeaderField: "host"))
}

func testInterceptRequest_withAuthTokenProvider_appendCorrectAuthInfoToPayload() async {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,15 @@ class AppSyncRealTimeClientFactoryTests: XCTestCase {
let appSyncEndpoint = URL(string: "https://abc.appsync-api.amazonaws.com/graphql")!
XCTAssertEqual(
AppSyncRealTimeClientFactory.appSyncRealTimeEndpoint(appSyncEndpoint),
URL(string: "https://abc.appsync-realtime-api.amazonaws.com/graphql")
URL(string: "wss://abc.appsync-realtime-api.amazonaws.com/graphql")
)
}

func testAppSyncRealTimeEndpoint_withAWSAppSyncRealTimeDomain_returnTheSameDomain() {
let appSyncEndpoint = URL(string: "https://abc.appsync-realtime-api.amazonaws.com/graphql")!
let appSyncEndpoint = URL(string: "wss://abc.appsync-realtime-api.amazonaws.com/graphql")!
XCTAssertEqual(
AppSyncRealTimeClientFactory.appSyncRealTimeEndpoint(appSyncEndpoint),
URL(string: "https://abc.appsync-realtime-api.amazonaws.com/graphql")
URL(string: "wss://abc.appsync-realtime-api.amazonaws.com/graphql")
)
}

Expand All @@ -34,4 +34,28 @@ class AppSyncRealTimeClientFactoryTests: XCTestCase {
URL(string: "https://test.example.com/graphql/realtime")
)
}

func testAppSyncApiEndpoint_withAWSAppSyncRealTimeDomain_returnCorrectApiDomain() {
let appSyncEndpoint = URL(string: "wss://abc.appsync-realtime-api.amazonaws.com/graphql")!
XCTAssertEqual(
AppSyncRealTimeClientFactory.appSyncApiEndpoint(appSyncEndpoint),
URL(string: "https://abc.appsync-api.amazonaws.com/graphql")
)
}

func testAppSyncApiEndpoint_withAWSAppSyncApiDomain_returnTheSameDomain() {
let appSyncEndpoint = URL(string: "https://abc.appsync-api.amazonaws.com/graphql")!
XCTAssertEqual(
AppSyncRealTimeClientFactory.appSyncApiEndpoint(appSyncEndpoint),
URL(string: "https://abc.appsync-api.amazonaws.com/graphql")
)
}

func testAppSyncApiEndpoint_withCustomDomain_returnCorrectRealtimePath() {
let appSyncEndpoint = URL(string: "https://test.example.com/graphql")!
XCTAssertEqual(
AppSyncRealTimeClientFactory.appSyncApiEndpoint(appSyncEndpoint),
URL(string: "https://test.example.com/graphql")
)
}
}
Loading

0 comments on commit 949d160

Please sign in to comment.