Changeset 313 in code for trunk


Ignore:
Timestamp:
Jun 4, 2020, 11:04:39 AM (5 years ago)
Author:
contact
Message:

Add network update command

The user.updateNetwork function is a bit involved because we need to
make sure that the upstream connection is closed before re-connecting
(would otherwise cause "Nick already used" errors) and that the
downstream connections' state is kept in sync.

References: https://todo.sr.ht/~emersion/soju/17

Location:
trunk
Files:
2 edited

Legend:

Unmodified
Added
Removed
  • trunk/service.go

    r307 r313  
    119119                        children: serviceCommandSet{
    120120                                "create": {
    121                                         usage:  "-addr <addr> [-name name] [-username username] [-pass pass] [-realname realname] [-nick nick] [[-connect-command command] ...]",
     121                                        usage:  "-addr <addr> [-name name] [-username username] [-pass pass] [-realname realname] [-nick nick] [-connect-command command]...",
    122122                                        desc:   "add a new network",
    123123                                        handle: handleServiceCreateNetwork,
     
    126126                                        desc:   "show a list of saved networks and their current status",
    127127                                        handle: handleServiceNetworkStatus,
     128                                },
     129                                "update": {
     130                                        usage: "[-addr addr] [-name name] [-username username] [-pass pass] [-realname realname] [-nick nick] [-connect-command command]...",
     131                                        desc:  "update a network",
     132                                        handle: handleServiceNetworkUpdate,
    128133                                },
    129134                                "delete": {
     
    339344}
    340345
    341 type stringSliceVar []string
    342 
    343 func (v *stringSliceVar) String() string {
     346type stringSliceFlag []string
     347
     348func (v *stringSliceFlag) String() string {
    344349        return fmt.Sprint([]string(*v))
    345350}
    346351
    347 func (v *stringSliceVar) Set(s string) error {
     352func (v *stringSliceFlag) Set(s string) error {
    348353        *v = append(*v, s)
    349354        return nil
    350355}
    351356
     357// stringPtrFlag is a flag value populating a string pointer. This allows to
     358// disambiguate between a flag that hasn't been set and a flag that has been
     359// set to an empty string.
     360type stringPtrFlag struct {
     361        ptr **string
     362}
     363
     364func (f stringPtrFlag) String() string {
     365        if *f.ptr == nil {
     366                return ""
     367        }
     368        return **f.ptr
     369}
     370
     371func (f stringPtrFlag) Set(s string) error {
     372        *f.ptr = &s
     373        return nil
     374}
     375
     376type networkFlagSet struct {
     377        *flag.FlagSet
     378        Addr, Name, Nick, Username, Pass, Realname *string
     379        ConnectCommands []string
     380}
     381
     382func newNetworkFlagSet() *networkFlagSet {
     383        fs := &networkFlagSet{FlagSet: newFlagSet()}
     384        fs.Var(stringPtrFlag{&fs.Addr}, "addr", "")
     385        fs.Var(stringPtrFlag{&fs.Name}, "name", "")
     386        fs.Var(stringPtrFlag{&fs.Nick}, "nick", "")
     387        fs.Var(stringPtrFlag{&fs.Username}, "username", "")
     388        fs.Var(stringPtrFlag{&fs.Pass}, "pass", "")
     389        fs.Var(stringPtrFlag{&fs.Realname}, "realname", "")
     390        fs.Var((*stringSliceFlag)(&fs.ConnectCommands), "connect-command", "")
     391        return fs
     392}
     393
     394func (fs *networkFlagSet) update(network *Network) error {
     395        if fs.Addr != nil {
     396                if addrParts := strings.SplitN(*fs.Addr, "://", 2); len(addrParts) == 2 {
     397                        scheme := addrParts[0]
     398                        switch scheme {
     399                        case "ircs", "irc+insecure":
     400                        default:
     401                                return fmt.Errorf("unknown scheme %q (supported schemes: ircs, irc+insecure)", scheme)
     402                        }
     403                }
     404                network.Addr = *fs.Addr
     405        }
     406        if fs.Name != nil {
     407                network.Name = *fs.Name
     408        }
     409        if fs.Nick != nil {
     410                network.Nick = *fs.Nick
     411        }
     412        if fs.Username != nil {
     413                network.Username = *fs.Username
     414        }
     415        if fs.Pass != nil {
     416                network.Pass = *fs.Pass
     417        }
     418        if fs.Realname != nil {
     419                network.Realname = *fs.Realname
     420        }
     421        if fs.ConnectCommands != nil {
     422                if len(fs.ConnectCommands) == 1 && fs.ConnectCommands[0] == "" {
     423                        network.ConnectCommands = nil
     424                } else {
     425                        for _, command := range fs.ConnectCommands {
     426                                _, err := irc.ParseMessage(command)
     427                                if err != nil {
     428                                        return fmt.Errorf("flag -connect-command must be a valid raw irc command string: %q: %v", command, err)
     429                                }
     430                        }
     431                        network.ConnectCommands = fs.ConnectCommands
     432                }
     433        }
     434        return nil
     435}
     436
    352437func handleServiceCreateNetwork(dc *downstreamConn, params []string) error {
    353         fs := newFlagSet()
    354         addr := fs.String("addr", "", "")
    355         name := fs.String("name", "", "")
    356         username := fs.String("username", "", "")
    357         pass := fs.String("pass", "", "")
    358         realname := fs.String("realname", "", "")
    359         nick := fs.String("nick", "", "")
    360         var connectCommands stringSliceVar
    361         fs.Var(&connectCommands, "connect-command", "")
    362 
     438        fs := newNetworkFlagSet()
    363439        if err := fs.Parse(params); err != nil {
    364440                return err
    365441        }
    366         if *addr == "" {
     442        if fs.Addr == nil {
    367443                return fmt.Errorf("flag -addr is required")
    368444        }
    369445
    370         if addrParts := strings.SplitN(*addr, "://", 2); len(addrParts) == 2 {
    371                 scheme := addrParts[0]
    372                 switch scheme {
    373                 case "ircs", "irc+insecure":
    374                 default:
    375                         return fmt.Errorf("unknown scheme %q (supported schemes: ircs, irc+insecure)", scheme)
    376                 }
    377         }
    378 
    379         for _, command := range connectCommands {
    380                 _, err := irc.ParseMessage(command)
    381                 if err != nil {
    382                         return fmt.Errorf("flag -connect-command must be a valid raw irc command string: %q: %v", command, err)
    383                 }
    384         }
    385 
    386         if *nick == "" {
    387                 *nick = dc.nick
    388         }
    389 
    390         var err error
    391         network, err := dc.user.createNetwork(&Network{
    392                 Addr:            *addr,
    393                 Name:            *name,
    394                 Username:        *username,
    395                 Pass:            *pass,
    396                 Realname:        *realname,
    397                 Nick:            *nick,
    398                 ConnectCommands: connectCommands,
    399         })
     446        record := &Network{
     447                Addr: *fs.Addr,
     448                Nick: dc.nick,
     449        }
     450        if err := fs.update(record); err != nil {
     451                return err
     452        }
     453
     454        network, err := dc.user.createNetwork(record)
    400455        if err != nil {
    401456                return fmt.Errorf("could not create network: %v", err)
     
    442497}
    443498
     499func handleServiceNetworkUpdate(dc *downstreamConn, params []string) error {
     500        if len(params) < 1 {
     501                return fmt.Errorf("expected exactly one argument")
     502        }
     503
     504        fs := newNetworkFlagSet()
     505        if err := fs.Parse(params[1:]); err != nil {
     506                return err
     507        }
     508
     509        net := dc.user.getNetwork(params[0])
     510        if net == nil {
     511                return fmt.Errorf("unknown network %q", params[0])
     512        }
     513
     514        record := net.Network // copy network record because we'll mutate it
     515        if err := fs.update(&record); err != nil {
     516                return err
     517        }
     518
     519        network, err := dc.user.updateNetwork(&record)
     520        if err != nil {
     521                return fmt.Errorf("could not update network: %v", err)
     522        }
     523
     524        sendServicePRIVMSG(dc, fmt.Sprintf("updated network %q", network.GetName()))
     525        return nil
     526}
     527
    444528func handleServiceNetworkDelete(dc *downstreamConn, params []string) error {
    445529        if len(params) != 1 {
  • trunk/user.go

    r311 r313  
    273273}
    274274
     275func (u *user) getNetworkByID(id int64) *network {
     276        for _, net := range u.networks {
     277                if net.ID == id {
     278                        return net
     279                }
     280        }
     281        return nil
     282}
     283
    275284func (u *user) run() {
    276285        networks, err := u.srv.db.ListNetworks(u.Username)
     
    310319                        uc.network.lastError = nil
    311320                case eventUpstreamDisconnected:
    312                         uc := e.uc
    313 
    314                         uc.network.conn = nil
    315 
    316                         for _, ml := range uc.messageLoggers {
    317                                 if err := ml.Close(); err != nil {
    318                                         uc.logger.Printf("failed to close message logger: %v", err)
    319                                 }
    320                         }
    321 
    322                         uc.endPendingLISTs(true)
    323 
    324                         uc.forEachDownstream(func(dc *downstreamConn) {
    325                                 dc.updateSupportedCaps()
    326                         })
    327 
    328                         if uc.network.lastError == nil {
    329                                 uc.forEachDownstream(func(dc *downstreamConn) {
    330                                         sendServiceNOTICE(dc, fmt.Sprintf("disconnected from %s", uc.network.GetName()))
    331                                 })
    332                         }
     321                        u.handleUpstreamDisconnected(e.uc)
    333322                case eventUpstreamConnectionError:
    334323                        net := e.net
    335324
    336                         if net.lastError == nil || net.lastError.Error() != e.err.Error() {
     325                        stopped := false
     326                        select {
     327                        case <-net.stopped:
     328                                stopped = true
     329                        default:
     330                        }
     331
     332                        if !stopped && (net.lastError == nil || net.lastError.Error() != e.err.Error()) {
    337333                                net.forEachDownstream(func(dc *downstreamConn) {
    338334                                        sendServiceNOTICE(dc, fmt.Sprintf("failed connecting/registering to %s: %v", net.GetName(), e.err))
     
    426422}
    427423
    428 func (u *user) createNetwork(net *Network) (*network, error) {
    429         if net.ID != 0 {
     424func (u *user) handleUpstreamDisconnected(uc *upstreamConn) {
     425        uc.network.conn = nil
     426
     427        for _, ml := range uc.messageLoggers {
     428                if err := ml.Close(); err != nil {
     429                        uc.logger.Printf("failed to close message logger: %v", err)
     430                }
     431        }
     432
     433        uc.endPendingLISTs(true)
     434
     435        uc.forEachDownstream(func(dc *downstreamConn) {
     436                dc.updateSupportedCaps()
     437        })
     438
     439        if uc.network.lastError == nil {
     440                uc.forEachDownstream(func(dc *downstreamConn) {
     441                        sendServiceNOTICE(dc, fmt.Sprintf("disconnected from %s", uc.network.GetName()))
     442                })
     443        }
     444}
     445
     446func (u *user) addNetwork(network *network) {
     447        u.networks = append(u.networks, network)
     448        go network.run()
     449}
     450
     451func (u *user) removeNetwork(network *network) {
     452        network.stop()
     453
     454        u.forEachDownstream(func(dc *downstreamConn) {
     455                if dc.network != nil && dc.network == network {
     456                        dc.Close()
     457                }
     458        })
     459
     460        for i, net := range u.networks {
     461                if net == network {
     462                        u.networks = append(u.networks[:i], u.networks[i+1:]...)
     463                        return
     464                }
     465        }
     466
     467        panic("tried to remove a non-existing network")
     468}
     469
     470func (u *user) createNetwork(record *Network) (*network, error) {
     471        if record.ID != 0 {
    430472                panic("tried creating an already-existing network")
    431473        }
    432474
    433         network := newNetwork(u, net, nil)
     475        network := newNetwork(u, record, nil)
    434476        err := u.srv.db.StoreNetwork(u.Username, &network.Network)
    435477        if err != nil {
     
    437479        }
    438480
    439         u.networks = append(u.networks, network)
    440 
    441         go network.run()
     481        u.addNetwork(network)
     482
    442483        return network, nil
    443484}
    444485
     486func (u *user) updateNetwork(record *Network) (*network, error) {
     487        if record.ID == 0 {
     488                panic("tried updating a new network")
     489        }
     490
     491        network := u.getNetworkByID(record.ID)
     492        if network == nil {
     493                panic("tried updating a non-existing network")
     494        }
     495
     496        if err := u.srv.db.StoreNetwork(u.Username, record); err != nil {
     497                return nil, err
     498        }
     499
     500        // Most network changes require us to re-connect to the upstream server
     501
     502        channels := make([]Channel, 0, len(network.channels))
     503        for _, ch := range network.channels {
     504                channels = append(channels, *ch)
     505        }
     506
     507        updatedNetwork := newNetwork(u, record, channels)
     508
     509        // If we're currently connected, disconnect and perform the necessary
     510        // bookkeeping
     511        if network.conn != nil {
     512                network.stop()
     513                // Note: this will set network.conn to nil
     514                u.handleUpstreamDisconnected(network.conn)
     515        }
     516
     517        // Patch downstream connections to use our fresh updated network
     518        u.forEachDownstream(func(dc *downstreamConn) {
     519                if dc.network != nil && dc.network == network {
     520                        dc.network = updatedNetwork
     521                }
     522        })
     523
     524        // We need to remove the network after patching downstream connections,
     525        // otherwise they'll get closed
     526        u.removeNetwork(network)
     527
     528        // This will re-connect to the upstream server
     529        u.addNetwork(updatedNetwork)
     530
     531        return updatedNetwork, nil
     532}
     533
    445534func (u *user) deleteNetwork(id int64) error {
    446         for i, net := range u.networks {
    447                 if net.ID != id {
    448                         continue
    449                 }
    450 
    451                 if err := u.srv.db.DeleteNetwork(net.ID); err != nil {
    452                         return err
    453                 }
    454 
    455                 u.forEachDownstream(func(dc *downstreamConn) {
    456                         if dc.network != nil && dc.network == net {
    457                                 dc.Close()
    458                         }
    459                 })
    460 
    461                 net.stop()
    462                 u.networks = append(u.networks[:i], u.networks[i+1:]...)
    463                 return nil
    464         }
    465 
    466         panic("tried deleting a non-existing network")
     535        network := u.getNetworkByID(id)
     536        if network == nil {
     537                panic("tried deleting a non-existing network")
     538        }
     539
     540        if err := u.srv.db.DeleteNetwork(network.ID); err != nil {
     541                return err
     542        }
     543
     544        u.removeNetwork(network)
     545        return nil
    467546}
    468547
Note: See TracChangeset for help on using the changeset viewer.