123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206 |
- /*
- *
- * Copyright 2016 gRPC authors.
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- *
- */
- package credentials
- import (
- "context"
- "crypto/tls"
- "net"
- "testing"
- "google.golang.org/grpc/testdata"
- )
- func TestTLSOverrideServerName(t *testing.T) {
- expectedServerName := "server.name"
- c := NewTLS(nil)
- c.OverrideServerName(expectedServerName)
- if c.Info().ServerName != expectedServerName {
- t.Fatalf("c.Info().ServerName = %v, want %v", c.Info().ServerName, expectedServerName)
- }
- }
- func TestTLSClone(t *testing.T) {
- expectedServerName := "server.name"
- c := NewTLS(nil)
- c.OverrideServerName(expectedServerName)
- cc := c.Clone()
- if cc.Info().ServerName != expectedServerName {
- t.Fatalf("cc.Info().ServerName = %v, want %v", cc.Info().ServerName, expectedServerName)
- }
- cc.OverrideServerName("")
- if c.Info().ServerName != expectedServerName {
- t.Fatalf("Change in clone should not affect the original, c.Info().ServerName = %v, want %v", c.Info().ServerName, expectedServerName)
- }
- }
// serverHandshake is the signature of a server-side handshake routine used by
// these tests: it is given the accepted connection and returns the AuthInfo
// produced by the handshake, or an error.
type serverHandshake func(net.Conn) (AuthInfo, error)
- func TestClientHandshakeReturnsAuthInfo(t *testing.T) {
- done := make(chan AuthInfo, 1)
- lis := launchServer(t, tlsServerHandshake, done)
- defer lis.Close()
- lisAddr := lis.Addr().String()
- clientAuthInfo := clientHandle(t, gRPCClientHandshake, lisAddr)
- // wait until server sends serverAuthInfo or fails.
- serverAuthInfo, ok := <-done
- if !ok {
- t.Fatalf("Error at server-side")
- }
- if !compare(clientAuthInfo, serverAuthInfo) {
- t.Fatalf("c.ClientHandshake(_, %v, _) = %v, want %v.", lisAddr, clientAuthInfo, serverAuthInfo)
- }
- }
- func TestServerHandshakeReturnsAuthInfo(t *testing.T) {
- done := make(chan AuthInfo, 1)
- lis := launchServer(t, gRPCServerHandshake, done)
- defer lis.Close()
- clientAuthInfo := clientHandle(t, tlsClientHandshake, lis.Addr().String())
- // wait until server sends serverAuthInfo or fails.
- serverAuthInfo, ok := <-done
- if !ok {
- t.Fatalf("Error at server-side")
- }
- if !compare(clientAuthInfo, serverAuthInfo) {
- t.Fatalf("ServerHandshake(_) = %v, want %v.", serverAuthInfo, clientAuthInfo)
- }
- }
- func TestServerAndClientHandshake(t *testing.T) {
- done := make(chan AuthInfo, 1)
- lis := launchServer(t, gRPCServerHandshake, done)
- defer lis.Close()
- clientAuthInfo := clientHandle(t, gRPCClientHandshake, lis.Addr().String())
- // wait until server sends serverAuthInfo or fails.
- serverAuthInfo, ok := <-done
- if !ok {
- t.Fatalf("Error at server-side")
- }
- if !compare(clientAuthInfo, serverAuthInfo) {
- t.Fatalf("AuthInfo returned by server: %v and client: %v aren't same", serverAuthInfo, clientAuthInfo)
- }
- }
- func compare(a1, a2 AuthInfo) bool {
- if a1.AuthType() != a2.AuthType() {
- return false
- }
- switch a1.AuthType() {
- case "tls":
- state1 := a1.(TLSInfo).State
- state2 := a2.(TLSInfo).State
- if state1.Version == state2.Version &&
- state1.HandshakeComplete == state2.HandshakeComplete &&
- state1.CipherSuite == state2.CipherSuite &&
- state1.NegotiatedProtocol == state2.NegotiatedProtocol {
- return true
- }
- return false
- default:
- return false
- }
- }
- func launchServer(t *testing.T, hs serverHandshake, done chan AuthInfo) net.Listener {
- lis, err := net.Listen("tcp", "localhost:0")
- if err != nil {
- t.Fatalf("Failed to listen: %v", err)
- }
- go serverHandle(t, hs, done, lis)
- return lis
- }
- // Is run in a separate goroutine.
- func serverHandle(t *testing.T, hs serverHandshake, done chan AuthInfo, lis net.Listener) {
- serverRawConn, err := lis.Accept()
- if err != nil {
- t.Errorf("Server failed to accept connection: %v", err)
- close(done)
- return
- }
- serverAuthInfo, err := hs(serverRawConn)
- if err != nil {
- t.Errorf("Server failed while handshake. Error: %v", err)
- serverRawConn.Close()
- close(done)
- return
- }
- done <- serverAuthInfo
- }
- func clientHandle(t *testing.T, hs func(net.Conn, string) (AuthInfo, error), lisAddr string) AuthInfo {
- conn, err := net.Dial("tcp", lisAddr)
- if err != nil {
- t.Fatalf("Client failed to connect to %s. Error: %v", lisAddr, err)
- }
- defer conn.Close()
- clientAuthInfo, err := hs(conn, lisAddr)
- if err != nil {
- t.Fatalf("Error on client while handshake. Error: %v", err)
- }
- return clientAuthInfo
- }
- // Server handshake implementation in gRPC.
- func gRPCServerHandshake(conn net.Conn) (AuthInfo, error) {
- serverTLS, err := NewServerTLSFromFile(testdata.Path("server1.pem"), testdata.Path("server1.key"))
- if err != nil {
- return nil, err
- }
- _, serverAuthInfo, err := serverTLS.ServerHandshake(conn)
- if err != nil {
- return nil, err
- }
- return serverAuthInfo, nil
- }
- // Client handshake implementation in gRPC.
- func gRPCClientHandshake(conn net.Conn, lisAddr string) (AuthInfo, error) {
- clientTLS := NewTLS(&tls.Config{InsecureSkipVerify: true})
- _, authInfo, err := clientTLS.ClientHandshake(context.Background(), lisAddr, conn)
- if err != nil {
- return nil, err
- }
- return authInfo, nil
- }
- func tlsServerHandshake(conn net.Conn) (AuthInfo, error) {
- cert, err := tls.LoadX509KeyPair(testdata.Path("server1.pem"), testdata.Path("server1.key"))
- if err != nil {
- return nil, err
- }
- serverTLSConfig := &tls.Config{Certificates: []tls.Certificate{cert}}
- serverConn := tls.Server(conn, serverTLSConfig)
- err = serverConn.Handshake()
- if err != nil {
- return nil, err
- }
- return TLSInfo{State: serverConn.ConnectionState()}, nil
- }
- func tlsClientHandshake(conn net.Conn, _ string) (AuthInfo, error) {
- clientTLSConfig := &tls.Config{InsecureSkipVerify: true}
- clientConn := tls.Client(conn, clientTLSConfig)
- if err := clientConn.Handshake(); err != nil {
- return nil, err
- }
- return TLSInfo{State: clientConn.ConnectionState()}, nil
- }
|