Changeset 313 in code for trunk/user.go


Ignore:
Timestamp:
Jun 4, 2020, 11:04:39 AM (5 years ago)
Author:
contact
Message:

Add network update command

The user.updateNetwork function is a bit involved because we need to
make sure that the upstream connection is closed before re-connecting
(would otherwise cause "Nick already used" errors) and that the
downstream connections' state is kept in sync.

References: https://todo.sr.ht/~emersion/soju/17

File:
1 edited

Legend:

Unmodified
Added
Removed
  • trunk/user.go

    r311 r313  
    273273}
    274274
     275func (u *user) getNetworkByID(id int64) *network {
     276        for _, net := range u.networks {
     277                if net.ID == id {
     278                        return net
     279                }
     280        }
     281        return nil
     282}
     283
    275284func (u *user) run() {
    276285        networks, err := u.srv.db.ListNetworks(u.Username)
     
    310319                        uc.network.lastError = nil
    311320                case eventUpstreamDisconnected:
    312                         uc := e.uc
    313 
    314                         uc.network.conn = nil
    315 
    316                         for _, ml := range uc.messageLoggers {
    317                                 if err := ml.Close(); err != nil {
    318                                         uc.logger.Printf("failed to close message logger: %v", err)
    319                                 }
    320                         }
    321 
    322                         uc.endPendingLISTs(true)
    323 
    324                         uc.forEachDownstream(func(dc *downstreamConn) {
    325                                 dc.updateSupportedCaps()
    326                         })
    327 
    328                         if uc.network.lastError == nil {
    329                                 uc.forEachDownstream(func(dc *downstreamConn) {
    330                                         sendServiceNOTICE(dc, fmt.Sprintf("disconnected from %s", uc.network.GetName()))
    331                                 })
    332                         }
     321                        u.handleUpstreamDisconnected(e.uc)
    333322                case eventUpstreamConnectionError:
    334323                        net := e.net
    335324
    336                         if net.lastError == nil || net.lastError.Error() != e.err.Error() {
     325                        stopped := false
     326                        select {
     327                        case <-net.stopped:
     328                                stopped = true
     329                        default:
     330                        }
     331
     332                        if !stopped && (net.lastError == nil || net.lastError.Error() != e.err.Error()) {
    337333                                net.forEachDownstream(func(dc *downstreamConn) {
    338334                                        sendServiceNOTICE(dc, fmt.Sprintf("failed connecting/registering to %s: %v", net.GetName(), e.err))
     
    426422}
    427423
    428 func (u *user) createNetwork(net *Network) (*network, error) {
    429         if net.ID != 0 {
     424func (u *user) handleUpstreamDisconnected(uc *upstreamConn) {
     425        uc.network.conn = nil
     426
     427        for _, ml := range uc.messageLoggers {
     428                if err := ml.Close(); err != nil {
     429                        uc.logger.Printf("failed to close message logger: %v", err)
     430                }
     431        }
     432
     433        uc.endPendingLISTs(true)
     434
     435        uc.forEachDownstream(func(dc *downstreamConn) {
     436                dc.updateSupportedCaps()
     437        })
     438
     439        if uc.network.lastError == nil {
     440                uc.forEachDownstream(func(dc *downstreamConn) {
     441                        sendServiceNOTICE(dc, fmt.Sprintf("disconnected from %s", uc.network.GetName()))
     442                })
     443        }
     444}
     445
     446func (u *user) addNetwork(network *network) {
     447        u.networks = append(u.networks, network)
     448        go network.run()
     449}
     450
     451func (u *user) removeNetwork(network *network) {
     452        network.stop()
     453
     454        u.forEachDownstream(func(dc *downstreamConn) {
     455                if dc.network != nil && dc.network == network {
     456                        dc.Close()
     457                }
     458        })
     459
     460        for i, net := range u.networks {
     461                if net == network {
     462                        u.networks = append(u.networks[:i], u.networks[i+1:]...)
     463                        return
     464                }
     465        }
     466
     467        panic("tried to remove a non-existing network")
     468}
     469
     470func (u *user) createNetwork(record *Network) (*network, error) {
     471        if record.ID != 0 {
    430472                panic("tried creating an already-existing network")
    431473        }
    432474
    433         network := newNetwork(u, net, nil)
     475        network := newNetwork(u, record, nil)
    434476        err := u.srv.db.StoreNetwork(u.Username, &network.Network)
    435477        if err != nil {
     
    437479        }
    438480
    439         u.networks = append(u.networks, network)
    440 
    441         go network.run()
     481        u.addNetwork(network)
     482
    442483        return network, nil
    443484}
    444485
     486func (u *user) updateNetwork(record *Network) (*network, error) {
     487        if record.ID == 0 {
     488                panic("tried updating a new network")
     489        }
     490
     491        network := u.getNetworkByID(record.ID)
     492        if network == nil {
     493                panic("tried updating a non-existing network")
     494        }
     495
     496        if err := u.srv.db.StoreNetwork(u.Username, record); err != nil {
     497                return nil, err
     498        }
     499
     500        // Most network changes require us to re-connect to the upstream server
     501
     502        channels := make([]Channel, 0, len(network.channels))
     503        for _, ch := range network.channels {
     504                channels = append(channels, *ch)
     505        }
     506
     507        updatedNetwork := newNetwork(u, record, channels)
     508
     509        // If we're currently connected, disconnect and perform the necessary
     510        // bookkeeping
     511        if network.conn != nil {
     512                network.stop()
     513                // Note: this will set network.conn to nil
     514                u.handleUpstreamDisconnected(network.conn)
     515        }
     516
     517        // Patch downstream connections to use our fresh updated network
     518        u.forEachDownstream(func(dc *downstreamConn) {
     519                if dc.network != nil && dc.network == network {
     520                        dc.network = updatedNetwork
     521                }
     522        })
     523
     524        // We need to remove the network after patching downstream connections,
     525        // otherwise they'll get closed
     526        u.removeNetwork(network)
     527
     528        // This will re-connect to the upstream server
     529        u.addNetwork(updatedNetwork)
     530
     531        return updatedNetwork, nil
     532}
     533
    445534func (u *user) deleteNetwork(id int64) error {
    446         for i, net := range u.networks {
    447                 if net.ID != id {
    448                         continue
    449                 }
    450 
    451                 if err := u.srv.db.DeleteNetwork(net.ID); err != nil {
    452                         return err
    453                 }
    454 
    455                 u.forEachDownstream(func(dc *downstreamConn) {
    456                         if dc.network != nil && dc.network == net {
    457                                 dc.Close()
    458                         }
    459                 })
    460 
    461                 net.stop()
    462                 u.networks = append(u.networks[:i], u.networks[i+1:]...)
    463                 return nil
    464         }
    465 
    466         panic("tried deleting a non-existing network")
     535        network := u.getNetworkByID(id)
     536        if network == nil {
     537                panic("tried deleting a non-existing network")
     538        }
     539
     540        if err := u.srv.db.DeleteNetwork(network.ID); err != nil {
     541                return err
     542        }
     543
     544        u.removeNetwork(network)
     545        return nil
    467546}
    468547
Note: See TracChangeset for help on using the changeset viewer.