diff options
Diffstat (limited to 'cmd/dendrite-demo-yggdrasil/yggconn/session.go')
-rw-r--r-- | cmd/dendrite-demo-yggdrasil/yggconn/session.go | 194 |
1 files changed, 149 insertions, 45 deletions
diff --git a/cmd/dendrite-demo-yggdrasil/yggconn/session.go b/cmd/dendrite-demo-yggdrasil/yggconn/session.go index 0d231f6d..0cf524d9 100644 --- a/cmd/dendrite-demo-yggdrasil/yggconn/session.go +++ b/cmd/dendrite-demo-yggdrasil/yggconn/session.go @@ -31,8 +31,32 @@ import ( "github.com/lucas-clemente/quic-go" "github.com/yggdrasil-network/yggdrasil-go/src/crypto" + "github.com/yggdrasil-network/yggdrasil-go/src/yggdrasil" ) +type session struct { + node *Node + session quic.Session + address string + context context.Context + cancel context.CancelFunc +} + +func (n *Node) newSession(sess quic.Session, address string) *session { + ctx, cancel := context.WithCancel(context.TODO()) + return &session{ + node: n, + session: sess, + address: address, + context: ctx, + cancel: cancel, + } +} + +func (s *session) kill() { + s.cancel() +} + func (n *Node) listenFromYgg() { var err error n.listener, err = quic.Listen( @@ -55,22 +79,31 @@ func (n *Node) listenFromYgg() { _ = session.CloseWithError(0, "expected a peer certificate") continue } - address := session.ConnectionState().PeerCertificates[0].Subject.CommonName + address := session.ConnectionState().PeerCertificates[0].DNSNames[0] n.log.Infoln("Accepted connection from", address) - go n.listenFromQUIC(session, address) + go n.newSession(session, address).listenFromQUIC() + go n.sessionFunc(address) } } -func (n *Node) listenFromQUIC(session quic.Session, address string) { - n.sessions.Store(address, session) - defer n.sessions.Delete(address) +func (s *session) listenFromQUIC() { + if existing, ok := s.node.sessions.Load(s.address); ok { + if existingSession, ok := existing.(*session); ok { + fmt.Println("Killing existing session to replace", s.address) + existingSession.kill() + } + } + s.node.sessionCount.Inc() + s.node.sessions.Store(s.address, s) + defer s.node.sessions.Delete(s.address) + defer s.node.sessionCount.Dec() for { - st, err := session.AcceptStream(context.TODO()) + st, err := s.session.AcceptStream(s.context) if err != nil { - n.log.Println("session.AcceptStream:", err) + s.node.log.Println("session.AcceptStream:", err) return } - n.incoming <- QUICStream{st, session} + s.node.incoming <- QUICStream{st, s.session} } } @@ -95,53 +128,124 @@ func (n *Node) Dial(network, address string) (net.Conn, error) { } // Implements http.Transport.DialContext +// nolint:gocyclo func (n *Node) DialContext(ctx context.Context, network, address string) (net.Conn, error) { s, ok1 := n.sessions.Load(address) - session, ok2 := s.(quic.Session) - if !ok1 || !ok2 || (ok1 && ok2 && session.ConnectionState().HandshakeComplete) { - dest, err := hex.DecodeString(address) - if err != nil { - return nil, err - } - if len(dest) != crypto.BoxPubKeyLen { - return nil, errors.New("invalid key length supplied") - } - var pubKey crypto.BoxPubKey - copy(pubKey[:], dest) - nodeID := crypto.GetNodeID(&pubKey) - nodeMask := &crypto.NodeID{} - for i := range nodeMask { - nodeMask[i] = 0xFF + session, ok2 := s.(*session) + if !ok1 || !ok2 { + // First of all, check if we think we know the coords of this + // node. If we do then we'll try to dial to it directly. This + // will either succeed or fail. + if v, ok := n.coords.Load(address); ok { + coords, ok := v.(yggdrasil.Coords) + if !ok { + n.coords.Delete(address) + return nil, errors.New("should have found yggdrasil.Coords but didn't") + } + n.log.Infof("Coords %s for %q cached, trying to dial", coords.String(), address) + var err error + // We think we know the coords. Try to dial the node. + if session, err = n.tryDial(address, coords); err != nil { + // We thought we knew the coords but it didn't result + // in a successful dial. Nuke them from the cache. + n.coords.Delete(address) + n.log.Infof("Cached coords %s for %q failed", coords.String(), address) + } } - fmt.Println("Resolving coords") - coords, err := n.core.Resolve(nodeID, nodeMask) - if err != nil { - return nil, fmt.Errorf("n.core.Resolve: %w", err) - } - fmt.Println("Found coords:", coords) - fmt.Println("Dialling") - - session, err = quic.Dial( - n.core, // yggdrasil.PacketConn - coords, // dial address - address, // dial SNI - n.tlsConfig, // TLS config - n.quicConfig, // QUIC config - ) - if err != nil { - n.log.Println("n.dialer.DialContext:", err) - return nil, err + // We either don't know the coords for the node, or we failed + // to dial it before, in which case try to resolve the coords. + if _, ok := n.coords.Load(address); !ok { + var coords yggdrasil.Coords + var err error + + // First look and see if the node is something that we already + // know about from our direct switch peers. + for _, peer := range n.core.GetSwitchPeers() { + if peer.PublicKey.String() == address { + coords = peer.Coords + n.log.Infof("%q is a direct peer, coords are %s", address, coords.String()) + n.coords.Store(address, coords) + break + } + } + + // If it isn' a node that we know directly then try to search + // the network. + if coords == nil { + n.log.Infof("Searching for coords for %q", address) + dest, derr := hex.DecodeString(address) + if derr != nil { + return nil, derr + } + if len(dest) != crypto.BoxPubKeyLen { + return nil, errors.New("invalid key length supplied") + } + var pubKey crypto.BoxPubKey + copy(pubKey[:], dest) + nodeID := crypto.GetNodeID(&pubKey) + nodeMask := &crypto.NodeID{} + for i := range nodeMask { + nodeMask[i] = 0xFF + } + + fmt.Println("Resolving coords") + coords, err = n.core.Resolve(nodeID, nodeMask) + if err != nil { + return nil, fmt.Errorf("n.core.Resolve: %w", err) + } + fmt.Println("Found coords:", coords) + n.coords.Store(address, coords) + } + + // We now know the coords in theory. Let's try dialling the + // node again. + if session, err = n.tryDial(address, coords); err != nil { + return nil, fmt.Errorf("n.tryDial: %w", err) + } } - fmt.Println("Dial OK") - go n.listenFromQUIC(session, address) } - st, err := session.OpenStream() + + if session == nil { + return nil, fmt.Errorf("should have found session but didn't") + } + + st, err := session.session.OpenStream() if err != nil { n.log.Println("session.OpenStream:", err) + _ = session.session.CloseWithError(0, "expected to be able to open session") return nil, err } - return QUICStream{st, session}, nil + return QUICStream{st, session.session}, nil +} + +func (n *Node) tryDial(address string, coords yggdrasil.Coords) (*session, error) { + quicSession, err := quic.Dial( + n.core, // yggdrasil.PacketConn + coords, // dial address + address, // dial SNI + n.tlsConfig, // TLS config + n.quicConfig, // QUIC config + ) + if err != nil { + return nil, err + } + if len(quicSession.ConnectionState().PeerCertificates) != 1 { + _ = quicSession.CloseWithError(0, "expected a peer certificate") + return nil, errors.New("didn't receive a peer certificate") + } + if len(quicSession.ConnectionState().PeerCertificates[0].DNSNames) != 1 { + _ = quicSession.CloseWithError(0, "expected a DNS name") + return nil, errors.New("didn't receive a DNS name") + } + if gotAddress := quicSession.ConnectionState().PeerCertificates[0].DNSNames[0]; address != gotAddress { + _ = quicSession.CloseWithError(0, "you aren't the host I was hoping for") + return nil, fmt.Errorf("expected %q but dialled %q", address, gotAddress) + } + session := n.newSession(quicSession, address) + go session.listenFromQUIC() + go n.sessionFunc(address) + return session, nil } func (n *Node) generateTLSConfig() *tls.Config { |