Changeset 168 in code for trunk


Ignore:
Timestamp:
Mar 27, 2020, 6:17:58 PM (5 years ago)
Author:
contact
Message:

Nuke user.lock

Split user.register into two functions, one to make sure the user is
authenticated, the other to send our current state. This allows to get
rid of data races by doing the second part in the user goroutine.

Closes: https://todo.sr.ht/~emersion/soju/22

Location:
trunk
Files:
2 edited

Legend:

Unmodified
Added
Removed
  • trunk/downstream.go

    r167 r168  
    7272        username    string
    7373        rawUsername string
     74        networkName string
    7475        realname    string
    7576        hostname    string
     
    583584}
    584585
    585 func (dc *downstreamConn) setNetwork(networkName string) error {
    586         if networkName == "" {
     586func (dc *downstreamConn) authenticate(username, password string) error {
     587        username, networkName := unmarshalUsername(username)
     588
     589        u := dc.srv.getUser(username)
     590        if u == nil {
     591                dc.logger.Printf("failed authentication for %q: unknown username", username)
     592                return errAuthFailed
     593        }
     594
     595        err := bcrypt.CompareHashAndPassword([]byte(u.Password), []byte(password))
     596        if err != nil {
     597                dc.logger.Printf("failed authentication for %q: %v", username, err)
     598                return errAuthFailed
     599        }
     600
     601        dc.user = u
     602        dc.networkName = networkName
     603        return nil
     604}
     605
     606func (dc *downstreamConn) register() error {
     607        if dc.registered {
     608                return fmt.Errorf("tried to register twice")
     609        }
     610
     611        password := dc.password
     612        dc.password = ""
     613        if dc.user == nil {
     614                if err := dc.authenticate(dc.rawUsername, password); err != nil {
     615                        return err
     616                }
     617        }
     618
     619        if dc.networkName == "" {
     620                _, dc.networkName = unmarshalUsername(dc.rawUsername)
     621        }
     622
     623        dc.registered = true
     624        dc.username = dc.user.Username
     625        dc.logger.Printf("registration complete for user %q", dc.username)
     626        return nil
     627}
     628
     629func (dc *downstreamConn) loadNetwork() error {
     630        if dc.networkName == "" {
    587631                return nil
    588632        }
    589633
    590         network := dc.user.getNetwork(networkName)
     634        network := dc.user.getNetwork(dc.networkName)
    591635        if network == nil {
    592                 addr := networkName
     636                addr := dc.networkName
    593637                if !strings.ContainsRune(addr, ':') {
    594638                        addr = addr + ":6697"
     
    600644                        return ircError{&irc.Message{
    601645                                Command: irc.ERR_PASSWDMISMATCH,
    602                                 Params:  []string{"*", fmt.Sprintf("Failed to connect to %q", networkName)},
     646                                Params:  []string{"*", fmt.Sprintf("Failed to connect to %q", dc.networkName)},
    603647                        }}
    604648                }
    605649
    606                 dc.logger.Printf("auto-saving network %q", networkName)
     650                dc.logger.Printf("auto-saving network %q", dc.networkName)
    607651                var err error
    608652                network, err = dc.user.createNetwork(&Network{
    609                         Addr: networkName,
     653                        Addr: dc.networkName,
    610654                        Nick: dc.nick,
    611655                })
     
    619663}
    620664
    621 func (dc *downstreamConn) authenticate(username, password string) error {
    622         username, networkName := unmarshalUsername(username)
    623 
    624         u := dc.srv.getUser(username)
    625         if u == nil {
    626                 dc.logger.Printf("failed authentication for %q: unknown username", username)
    627                 return errAuthFailed
    628         }
    629 
    630         err := bcrypt.CompareHashAndPassword([]byte(u.Password), []byte(password))
    631         if err != nil {
    632                 dc.logger.Printf("failed authentication for %q: %v", username, err)
    633                 return errAuthFailed
    634         }
    635 
    636         dc.user = u
    637 
    638         return dc.setNetwork(networkName)
    639 }
    640 
    641 func (dc *downstreamConn) register() error {
    642         password := dc.password
    643         dc.password = ""
    644         if dc.user == nil {
    645                 if err := dc.authenticate(dc.rawUsername, password); err != nil {
    646                         return err
    647                 }
    648         } else if dc.network == nil {
    649                 _, networkName := unmarshalUsername(dc.rawUsername)
    650                 if err := dc.setNetwork(networkName); err != nil {
    651                         return err
    652                 }
    653         }
    654 
    655         dc.registered = true
    656         dc.username = dc.user.Username
    657         dc.logger.Printf("registration complete for user %q", dc.username)
    658 
    659         dc.user.lock.Lock()
     665func (dc *downstreamConn) welcome() error {
     666        if dc.user == nil || !dc.registered {
     667                panic("tried to welcome an unregistered connection")
     668        }
     669
     670        // TODO: doing this might take some time. We should do it in dc.register
     671        // instead, but we'll potentially be adding a new network and this must be
     672        // done in the user goroutine.
     673        if err := dc.loadNetwork(); err != nil {
     674                return err
     675        }
     676
    660677        firstDownstream := len(dc.user.downstreamConns) == 0
    661         dc.user.lock.Unlock()
    662678
    663679        dc.SendMessage(&irc.Message{
  • trunk/user.go

    r167 r168  
    9292        events chan event
    9393
    94         lock            sync.Mutex
    9594        networks        []*network
    9695        downstreamConns []*downstreamConn
     
    106105
    107106func (u *user) forEachNetwork(f func(*network)) {
    108         u.lock.Lock()
    109107        for _, network := range u.networks {
    110108                f(network)
    111109        }
    112         u.lock.Unlock()
    113110}
    114111
    115112func (u *user) forEachUpstream(f func(uc *upstreamConn)) {
    116         u.lock.Lock()
    117113        for _, network := range u.networks {
    118114                uc := network.upstream()
     
    122118                f(uc)
    123119        }
    124         u.lock.Unlock()
    125120}
    126121
    127122func (u *user) forEachDownstream(f func(dc *downstreamConn)) {
    128         u.lock.Lock()
    129123        for _, dc := range u.downstreamConns {
    130124                f(dc)
    131125        }
    132         u.lock.Unlock()
    133126}
    134127
     
    149142        }
    150143
    151         u.lock.Lock()
    152144        for _, record := range networks {
    153145                network := newNetwork(u, &record)
     
    156148                go network.run()
    157149        }
    158         u.lock.Unlock()
    159150
    160151        for e := range u.events {
     
    171162                case eventDownstreamConnected:
    172163                        dc := e.dc
    173                         u.lock.Lock()
     164
     165                        if err := dc.welcome(); err != nil {
     166                                dc.logger.Printf("failed to handle new registered connection: %v", err)
     167                                break
     168                        }
     169
    174170                        u.downstreamConns = append(u.downstreamConns, dc)
    175                         u.lock.Unlock()
    176171                case eventDownstreamDisconnected:
    177172                        dc := e.dc
    178                         u.lock.Lock()
    179173                        for i := range u.downstreamConns {
    180174                                if u.downstreamConns[i] == dc {
     
    183177                                }
    184178                        }
    185                         u.lock.Unlock()
    186179                case eventDownstreamMessage:
    187180                        msg, dc := e.msg, e.dc
     
    221214        })
    222215
    223         u.lock.Lock()
    224216        u.networks = append(u.networks, network)
    225         u.lock.Unlock()
    226217
    227218        go network.run()
Note: See TracChangeset for help on using the changeset viewer.