Move validateSession back into refreshSessionIfNeeded
This commit is contained in:
		
							parent
							
								
									ad8ce2f6a4
								
							
						
					
					
						commit
						d9e0933e54
					
				| 
						 | 
					@ -103,13 +103,7 @@ func (s *storedSessionLoader) getValidatedSession(rw http.ResponseWriter, req *h
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	err = s.refreshSessionIfNeeded(rw, req, session)
 | 
						err = s.refreshSessionIfNeeded(rw, req, session)
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		logger.Errorf("error refreshing access token for session (%s): %v", session, err)
 | 
							return nil, fmt.Errorf("error refreshing access token for session (%s): %v", session, err)
 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	// Validate all sessions after any Redeem/Refresh operation (fail or success)
 | 
					 | 
				
			||||||
	err = s.validateSession(req.Context(), session)
 | 
					 | 
				
			||||||
	if err != nil {
 | 
					 | 
				
			||||||
		return nil, err
 | 
					 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	return session, nil
 | 
						return session, nil
 | 
				
			||||||
| 
						 | 
					@ -133,11 +127,22 @@ func (s *storedSessionLoader) refreshSessionIfNeeded(rw http.ResponseWriter, req
 | 
				
			||||||
	// it should be updated after lock is released.
 | 
						// it should be updated after lock is released.
 | 
				
			||||||
	if wasLocked {
 | 
						if wasLocked {
 | 
				
			||||||
		logger.Printf("Update session from store instead of refreshing")
 | 
							logger.Printf("Update session from store instead of refreshing")
 | 
				
			||||||
		return s.updateSessionFromStore(req, session)
 | 
							err = s.updateSessionFromStore(req, session)
 | 
				
			||||||
 | 
							if err != nil {
 | 
				
			||||||
 | 
								logger.Errorf("Unable to update session from store: %v", err)
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						} else {
 | 
				
			||||||
 | 
							logger.Printf("Refreshing session - User: %s; SessionAge: %s", session.User, session.Age())
 | 
				
			||||||
 | 
							err = s.refreshSession(rw, req, session)
 | 
				
			||||||
 | 
							if err != nil {
 | 
				
			||||||
 | 
								// If a preemptive refresh fails, we still keep the session
 | 
				
			||||||
 | 
								// if validateSession succeeds.
 | 
				
			||||||
 | 
								logger.Errorf("Unable to refresh session: %v", err)
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	logger.Printf("Refreshing session - User: %s; SessionAge: %s", session.User, session.Age())
 | 
						// Validate all sessions after any Redeem/Refresh operation (fail or success)
 | 
				
			||||||
	return s.refreshSession(rw, req, session)
 | 
						return s.validateSession(req.Context(), session)
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// refreshSession attempts to refresh the session with the provider
 | 
					// refreshSession attempts to refresh the session with the provider
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -460,6 +460,7 @@ var _ = Describe("Stored Session Suite", func() {
 | 
				
			||||||
			session           *sessionsapi.SessionState
 | 
								session           *sessionsapi.SessionState
 | 
				
			||||||
			expectedErr       error
 | 
								expectedErr       error
 | 
				
			||||||
			expectRefreshed   bool
 | 
								expectRefreshed   bool
 | 
				
			||||||
 | 
								expectValidated   bool
 | 
				
			||||||
			expectedLockState TestLock
 | 
								expectedLockState TestLock
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -469,6 +470,7 @@ var _ = Describe("Stored Session Suite", func() {
 | 
				
			||||||
		DescribeTable("with a session",
 | 
							DescribeTable("with a session",
 | 
				
			||||||
			func(in refreshSessionIfNeededTableInput) {
 | 
								func(in refreshSessionIfNeededTableInput) {
 | 
				
			||||||
				refreshed := false
 | 
									refreshed := false
 | 
				
			||||||
 | 
									validated := false
 | 
				
			||||||
 | 
					
 | 
				
			||||||
				store := &fakeSessionStore{}
 | 
									store := &fakeSessionStore{}
 | 
				
			||||||
				if in.sessionStored {
 | 
									if in.sessionStored {
 | 
				
			||||||
| 
						 | 
					@ -496,6 +498,7 @@ var _ = Describe("Stored Session Suite", func() {
 | 
				
			||||||
						}
 | 
											}
 | 
				
			||||||
					},
 | 
										},
 | 
				
			||||||
					sessionValidator: func(_ context.Context, ss *sessionsapi.SessionState) bool {
 | 
										sessionValidator: func(_ context.Context, ss *sessionsapi.SessionState) bool {
 | 
				
			||||||
 | 
											validated = true
 | 
				
			||||||
						return ss.AccessToken != "Invalid"
 | 
											return ss.AccessToken != "Invalid"
 | 
				
			||||||
					},
 | 
										},
 | 
				
			||||||
				}
 | 
									}
 | 
				
			||||||
| 
						 | 
					@ -508,6 +511,7 @@ var _ = Describe("Stored Session Suite", func() {
 | 
				
			||||||
					Expect(err).ToNot(HaveOccurred())
 | 
										Expect(err).ToNot(HaveOccurred())
 | 
				
			||||||
				}
 | 
									}
 | 
				
			||||||
				Expect(refreshed).To(Equal(in.expectRefreshed))
 | 
									Expect(refreshed).To(Equal(in.expectRefreshed))
 | 
				
			||||||
 | 
									Expect(validated).To(Equal(in.expectValidated))
 | 
				
			||||||
				testLock, ok := in.session.Lock.(*TestLock)
 | 
									testLock, ok := in.session.Lock.(*TestLock)
 | 
				
			||||||
				Expect(ok).To(Equal(true))
 | 
									Expect(ok).To(Equal(true))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -522,6 +526,7 @@ var _ = Describe("Stored Session Suite", func() {
 | 
				
			||||||
				},
 | 
									},
 | 
				
			||||||
				expectedErr:       nil,
 | 
									expectedErr:       nil,
 | 
				
			||||||
				expectRefreshed:   false,
 | 
									expectRefreshed:   false,
 | 
				
			||||||
 | 
									expectValidated:   false,
 | 
				
			||||||
				expectedLockState: TestLock{},
 | 
									expectedLockState: TestLock{},
 | 
				
			||||||
			}),
 | 
								}),
 | 
				
			||||||
			Entry("when the refresh period is 0, and the session needs refreshing", refreshSessionIfNeededTableInput{
 | 
								Entry("when the refresh period is 0, and the session needs refreshing", refreshSessionIfNeededTableInput{
 | 
				
			||||||
| 
						 | 
					@ -533,6 +538,7 @@ var _ = Describe("Stored Session Suite", func() {
 | 
				
			||||||
				},
 | 
									},
 | 
				
			||||||
				expectedErr:       nil,
 | 
									expectedErr:       nil,
 | 
				
			||||||
				expectRefreshed:   false,
 | 
									expectRefreshed:   false,
 | 
				
			||||||
 | 
									expectValidated:   false,
 | 
				
			||||||
				expectedLockState: TestLock{},
 | 
									expectedLockState: TestLock{},
 | 
				
			||||||
			}),
 | 
								}),
 | 
				
			||||||
			Entry("when the session does not need refreshing", refreshSessionIfNeededTableInput{
 | 
								Entry("when the session does not need refreshing", refreshSessionIfNeededTableInput{
 | 
				
			||||||
| 
						 | 
					@ -544,6 +550,7 @@ var _ = Describe("Stored Session Suite", func() {
 | 
				
			||||||
				},
 | 
									},
 | 
				
			||||||
				expectedErr:       nil,
 | 
									expectedErr:       nil,
 | 
				
			||||||
				expectRefreshed:   false,
 | 
									expectRefreshed:   false,
 | 
				
			||||||
 | 
									expectValidated:   false,
 | 
				
			||||||
				expectedLockState: TestLock{},
 | 
									expectedLockState: TestLock{},
 | 
				
			||||||
			}),
 | 
								}),
 | 
				
			||||||
			Entry("when the session is refreshed by the provider", refreshSessionIfNeededTableInput{
 | 
								Entry("when the session is refreshed by the provider", refreshSessionIfNeededTableInput{
 | 
				
			||||||
| 
						 | 
					@ -555,6 +562,7 @@ var _ = Describe("Stored Session Suite", func() {
 | 
				
			||||||
				},
 | 
									},
 | 
				
			||||||
				expectedErr:     nil,
 | 
									expectedErr:     nil,
 | 
				
			||||||
				expectRefreshed: true,
 | 
									expectRefreshed: true,
 | 
				
			||||||
 | 
									expectValidated: true,
 | 
				
			||||||
				expectedLockState: TestLock{
 | 
									expectedLockState: TestLock{
 | 
				
			||||||
					Locked:      false,
 | 
										Locked:      false,
 | 
				
			||||||
					WasObtained: true,
 | 
										WasObtained: true,
 | 
				
			||||||
| 
						 | 
					@ -574,6 +582,7 @@ var _ = Describe("Stored Session Suite", func() {
 | 
				
			||||||
				sessionStored:   true,
 | 
									sessionStored:   true,
 | 
				
			||||||
				expectedErr:     nil,
 | 
									expectedErr:     nil,
 | 
				
			||||||
				expectRefreshed: false,
 | 
									expectRefreshed: false,
 | 
				
			||||||
 | 
									expectValidated: true,
 | 
				
			||||||
				expectedLockState: TestLock{
 | 
									expectedLockState: TestLock{
 | 
				
			||||||
					Locked:      false,
 | 
										Locked:      false,
 | 
				
			||||||
					PeekedCount: 2,
 | 
										PeekedCount: 2,
 | 
				
			||||||
| 
						 | 
					@ -590,6 +599,7 @@ var _ = Describe("Stored Session Suite", func() {
 | 
				
			||||||
				},
 | 
									},
 | 
				
			||||||
				expectedErr:     nil,
 | 
									expectedErr:     nil,
 | 
				
			||||||
				expectRefreshed: false,
 | 
									expectRefreshed: false,
 | 
				
			||||||
 | 
									expectValidated: true,
 | 
				
			||||||
				expectedLockState: TestLock{
 | 
									expectedLockState: TestLock{
 | 
				
			||||||
					PeekedCount: 1,
 | 
										PeekedCount: 1,
 | 
				
			||||||
					ObtainError: errors.New("not able to obtain lock"),
 | 
										ObtainError: errors.New("not able to obtain lock"),
 | 
				
			||||||
| 
						 | 
					@ -605,6 +615,7 @@ var _ = Describe("Stored Session Suite", func() {
 | 
				
			||||||
				},
 | 
									},
 | 
				
			||||||
				expectedErr:     nil,
 | 
									expectedErr:     nil,
 | 
				
			||||||
				expectRefreshed: true,
 | 
									expectRefreshed: true,
 | 
				
			||||||
 | 
									expectValidated: true,
 | 
				
			||||||
				expectedLockState: TestLock{
 | 
									expectedLockState: TestLock{
 | 
				
			||||||
					Locked:      false,
 | 
										Locked:      false,
 | 
				
			||||||
					WasObtained: true,
 | 
										WasObtained: true,
 | 
				
			||||||
| 
						 | 
					@ -621,6 +632,7 @@ var _ = Describe("Stored Session Suite", func() {
 | 
				
			||||||
				},
 | 
									},
 | 
				
			||||||
				expectedErr:     nil,
 | 
									expectedErr:     nil,
 | 
				
			||||||
				expectRefreshed: true,
 | 
									expectRefreshed: true,
 | 
				
			||||||
 | 
									expectValidated: true,
 | 
				
			||||||
				expectedLockState: TestLock{
 | 
									expectedLockState: TestLock{
 | 
				
			||||||
					Locked:      false,
 | 
										Locked:      false,
 | 
				
			||||||
					WasObtained: true,
 | 
										WasObtained: true,
 | 
				
			||||||
| 
						 | 
					@ -637,8 +649,9 @@ var _ = Describe("Stored Session Suite", func() {
 | 
				
			||||||
					ExpiresOn:    &createdFuture,
 | 
										ExpiresOn:    &createdFuture,
 | 
				
			||||||
					Lock:         &TestLock{},
 | 
										Lock:         &TestLock{},
 | 
				
			||||||
				},
 | 
									},
 | 
				
			||||||
				expectedErr:     nil,
 | 
									expectedErr:     errors.New("session is invalid"),
 | 
				
			||||||
				expectRefreshed: true,
 | 
									expectRefreshed: true,
 | 
				
			||||||
 | 
									expectValidated: true,
 | 
				
			||||||
				expectedLockState: TestLock{
 | 
									expectedLockState: TestLock{
 | 
				
			||||||
					Locked:      false,
 | 
										Locked:      false,
 | 
				
			||||||
					WasObtained: true,
 | 
										WasObtained: true,
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in New Issue