Skip to content

Commit

Permalink
Merge pull request #181 from WalletConnect/feat/chain-switching
Browse files Browse the repository at this point in the history
feat: chain switching
  • Loading branch information
skibitsky authored Mar 19, 2024
2 parents 2a1aed2 + 72beae5 commit 2d45af3
Show file tree
Hide file tree
Showing 7 changed files with 218 additions and 138 deletions.
47 changes: 41 additions & 6 deletions Tests/WalletConnectSharp.Sign.Test/SignTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ public static async Task<SessionStruct> TestConnectMethod(ISignClient clientA, I
},
Chains = new[]
{
"eip155:1"
"eip155:1", "eip155:10"
},
Events = new[]
{
Expand Down Expand Up @@ -149,7 +149,7 @@ public async Task TestRejectSession()
},
Chains = new[]
{
"eip155:1"
"eip155:1", "eip155:10"
},
Events = new[]
{
Expand Down Expand Up @@ -192,7 +192,7 @@ public async Task TestSessionRequestResponse()
},
Chains = new[]
{
"eip155:1"
"eip155:1", "eip155:10"
},
Events = new[]
{
Expand Down Expand Up @@ -294,7 +294,7 @@ public async Task TestTwoUniqueSessionRequestResponse()
},
Chains = new[]
{
"eip155:1"
"eip155:1", "eip155:10"
},
Events = new[]
{
Expand Down Expand Up @@ -555,14 +555,19 @@ public async Task TestAddressProviderDefaults()
var address = dappClient.AddressProvider.CurrentAddress();
Assert.Equal(testAddress, address.Address);
Assert.Equal("eip155:1", address.ChainId);
Assert.Equal("eip155:1", dappClient.AddressProvider.DefaultChain);
Assert.Equal("eip155:1", dappClient.AddressProvider.DefaultChainId);
Assert.Equal("eip155", dappClient.AddressProvider.DefaultNamespace);

address = walletClient.AddressProvider.CurrentAddress();
Assert.Equal(testAddress, address.Address);
Assert.Equal("eip155:1", address.ChainId);
Assert.Equal("eip155:1", dappClient.AddressProvider.DefaultChain);
Assert.Equal("eip155:1", dappClient.AddressProvider.DefaultChainId);
Assert.Equal("eip155", dappClient.AddressProvider.DefaultNamespace);

var allAddresses = dappClient.AddressProvider.AllAddresses("eip155").ToArray();
Assert.Single(allAddresses);
Assert.Equal(testAddress, allAddresses[0].Address);
Assert.Equal("eip155:1", allAddresses[0].ChainId);
}

[Fact, Trait("Category", "integration")]
Expand Down Expand Up @@ -591,5 +596,35 @@ public async Task TestAddressProviderDefaultsSaving()

await TestTwoUniqueSessionRequestResponseUsingAddressProviderDefaults();
}

[Fact] [Trait("Category", "integration")]
public async Task TestAddressProviderChainIdChange()
{
await _cryptoFixture.WaitForClientsReady();

_ = await TestConnectMethod(ClientA, ClientB);

const string badChainId = "invalid:invalid";
await Assert.ThrowsAsync<InvalidOperationException>(() => ClientA.AddressProvider.SetDefaultChainIdAsync(badChainId));

// Change the default chain id
const string newChainId = "eip155:10";
await ClientA.AddressProvider.SetDefaultChainIdAsync(newChainId);
Assert.Equal(newChainId, ClientA.AddressProvider.DefaultChainId);
}

[Fact] [Trait("Category", "integration")]
public async Task TestAddressProviderDisconnect()
{
await _cryptoFixture.WaitForClientsReady();

_ = await TestConnectMethod(ClientA, ClientB);

Assert.True(ClientA.AddressProvider.HasDefaultSession);

await ClientA.Disconnect();

Assert.False(ClientA.AddressProvider.HasDefaultSession);
}
}
}
187 changes: 100 additions & 87 deletions WalletConnectSharp.Sign/Controllers/AddressProvider.cs
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
using Newtonsoft.Json;
using WalletConnectSharp.Common.Logging;
using WalletConnectSharp.Sign.Interfaces;
using WalletConnectSharp.Sign.Interfaces;
using WalletConnectSharp.Sign.Models;
using WalletConnectSharp.Sign.Models.Engine.Events;

namespace WalletConnectSharp.Sign.Controllers;

public class AddressProvider : IAddressProvider
{
private bool _disposed;

public struct DefaultData
{
public SessionStruct Session;
Expand Down Expand Up @@ -67,7 +67,7 @@ public string DefaultNamespace
}
}

public string DefaultChain
public string DefaultChainId
{
get
{
Expand All @@ -91,7 +91,7 @@ public AddressProvider(ISignClient client)
// set the first connected session to the default one
client.SessionConnected += ClientOnSessionConnected;
client.SessionDeleted += ClientOnSessionDeleted;
client.SessionUpdated += ClientOnSessionUpdated;
client.SessionUpdateRequest += ClientOnSessionUpdated;
client.SessionApproved += ClientOnSessionConnected;
}

Expand All @@ -115,142 +115,155 @@ public virtual async Task LoadDefaults()
DefaultsLoaded?.Invoke(this, new DefaultsLoadingEventArgs(_state));
}

private void ClientOnSessionUpdated(object sender, SessionEvent e)
private async void ClientOnSessionUpdated(object sender, SessionEvent e)
{
if (DefaultSession.Topic == e.Topic)
{
UpdateDefaultChainAndNamespace();
DefaultSession = Sessions.Get(e.Topic);
await UpdateDefaultChainIdAndNamespaceAsync();
}
}

private void ClientOnSessionDeleted(object sender, SessionEvent e)
private async void ClientOnSessionDeleted(object sender, SessionEvent e)
{
if (DefaultSession.Topic == e.Topic)
{
DefaultSession = default;
UpdateDefaultChainAndNamespace();
await UpdateDefaultChainIdAndNamespaceAsync();
}
}

private void ClientOnSessionConnected(object sender, SessionStruct e)
private async void ClientOnSessionConnected(object sender, SessionStruct e)
{
if (!HasDefaultSession)
{
DefaultSession = e;
UpdateDefaultChainAndNamespace();
await UpdateDefaultChainIdAndNamespaceAsync();
}
}

private async void UpdateDefaultChainAndNamespace()
private async Task UpdateDefaultChainIdAndNamespaceAsync()
{
try
if (HasDefaultSession)
{
if (HasDefaultSession)
// Check if current default namespace is still valid with the current session
var currentDefault = DefaultNamespace;
if (currentDefault != null && DefaultSession.Namespaces.ContainsKey(currentDefault))
{
var currentDefault = DefaultNamespace;
if (currentDefault != null && DefaultSession.Namespaces.ContainsKey(currentDefault))
// Check if current default chain is still valid with the current session
var currentChain = DefaultChainId;
if (currentChain == null || !DefaultSession.Namespaces[DefaultNamespace].Chains.Contains(currentChain))
{
// DefaultNamespace is still valid
var currentChain = DefaultChain;
if (currentChain == null ||
DefaultSession.Namespaces[DefaultNamespace].Chains.Contains(currentChain))
{
// DefaultChain is still valid
await SaveDefaults();
return;
}

DefaultChain = DefaultSession.Namespaces[DefaultNamespace].Chains[0];
await SaveDefaults();
return;
// If the current default chain is not valid, let's use the first one
DefaultChainId = DefaultSession.Namespaces[DefaultNamespace].Chains[0];
}

// DefaultNamespace is null or not found in current available spaces, update it
}
else
{
// If DefaultNamespace is null or not found in current available spaces, update it
DefaultNamespace = DefaultSession.Namespaces.Keys.FirstOrDefault();
if (DefaultNamespace != null)
if (DefaultNamespace != null && DefaultSession.Namespaces[DefaultNamespace].Chains != null)
{
if (DefaultSession.Namespaces.ContainsKey(DefaultNamespace) &&
DefaultSession.Namespaces[DefaultNamespace].Chains != null)
{
DefaultChain = DefaultSession.Namespaces[DefaultNamespace].Chains[0];
}
else if (DefaultSession.RequiredNamespaces.ContainsKey(DefaultNamespace) &&
DefaultSession.RequiredNamespaces[DefaultNamespace].Chains != null)
{
// We don't know what chain to use? Let's use the required one as a fallback
DefaultChain = DefaultSession.RequiredNamespaces[DefaultNamespace].Chains[0];
}
DefaultChainId = DefaultSession.Namespaces[DefaultNamespace].Chains[0];
}
else
{
DefaultNamespace = DefaultSession.Namespaces.Keys.FirstOrDefault();
if (DefaultNamespace != null && DefaultSession.Namespaces[DefaultNamespace].Chains != null)
{
DefaultChain = DefaultSession.Namespaces[DefaultNamespace].Chains[0];
}
else
{
// We don't know what chain to use? Let's use the required one as a fallback
DefaultNamespace = DefaultSession.RequiredNamespaces.Keys.FirstOrDefault();
if (DefaultNamespace != null &&
DefaultSession.RequiredNamespaces[DefaultNamespace].Chains != null)
{
DefaultChain = DefaultSession.RequiredNamespaces[DefaultNamespace].Chains[0];
}
else
{
WCLogger.LogError("Could not figure out default chain to use");
}
}
throw new InvalidOperationException("Could not figure out default chain and namespace");
}
}
else
{
DefaultNamespace = null;
}

await SaveDefaults();
}
catch (Exception e)
else
{
WCLogger.LogError(e);
throw;
DefaultNamespace = null;
DefaultChainId = null;
}
}

public Caip25Address CurrentAddress(string @namespace = null, SessionStruct session = default)
public async Task InitAsync()
{
@namespace ??= DefaultNamespace;
if (string.IsNullOrWhiteSpace(session.Topic)) // default
session = DefaultSession;
await this.LoadDefaults();
}

public async Task SetDefaultNamespaceAsync(string @namespace)
{
if (string.IsNullOrWhiteSpace(@namespace))
{
throw new ArgumentNullException(nameof(@namespace));
}

return session.CurrentAddress(@namespace);
if (!DefaultSession.Namespaces.ContainsKey(@namespace))
{
throw new InvalidOperationException($"Namespace {@namespace} is not available in the current session");
}

DefaultNamespace = @namespace;
await SaveDefaults();
}

public async Task SetDefaultChainIdAsync(string chainId)
{
if (string.IsNullOrWhiteSpace(chainId))
{
throw new ArgumentNullException(nameof(chainId));
}

public async Task Init()
if (!DefaultSession.Namespaces[DefaultNamespace].Chains.Contains(chainId))
{
throw new InvalidOperationException($"Chain {chainId} is not available in the current session");
}

DefaultChainId = chainId;
await SaveDefaults();
}

public Caip25Address CurrentAddress(string chainId = null, SessionStruct session = default)
{
await this.LoadDefaults();
chainId ??= DefaultChainId;
if (string.IsNullOrWhiteSpace(session.Topic))
{
session = DefaultSession;
}

return session.CurrentAddress(chainId);
}

public Caip25Address[] AllAddresses(string @namespace = null, SessionStruct session = default)
public IEnumerable<Caip25Address> AllAddresses(string @namespace = null, SessionStruct session = default)
{
@namespace ??= DefaultNamespace;
if (string.IsNullOrWhiteSpace(session.Topic)) // default
session = DefaultSession;

return session.AllAddresses(@namespace);
}

public void Dispose()
{
_client.SessionConnected -= ClientOnSessionConnected;
_client.SessionDeleted -= ClientOnSessionDeleted;
_client.SessionUpdated -= ClientOnSessionUpdated;
_client.SessionApproved -= ClientOnSessionConnected;

_client = null;
Sessions = null;
DefaultNamespace = null;
DefaultSession = default;
Dispose(true);
GC.SuppressFinalize(this);
}

protected virtual void Dispose(bool disposing)
{
if (_disposed)
{
return;
}

if (disposing)
{
_client.SessionConnected -= ClientOnSessionConnected;
_client.SessionDeleted -= ClientOnSessionDeleted;
_client.SessionUpdateRequest -= ClientOnSessionUpdated;
_client.SessionApproved -= ClientOnSessionConnected;

_client = null;
Sessions = null;
DefaultNamespace = null;
DefaultSession = default;
}

_disposed = true;
}
}
6 changes: 3 additions & 3 deletions WalletConnectSharp.Sign/Engine.cs
Original file line number Diff line number Diff line change
Expand Up @@ -325,7 +325,7 @@ public Task<IAcknowledgement> Extend()
public Task<TR> Request<T, TR>(T data, string chainId = null, long? expiry = null)
{
return Request<T, TR>(Client.AddressProvider.DefaultSession.Topic, data,
chainId ?? Client.AddressProvider.DefaultChain, expiry);
chainId ?? Client.AddressProvider.DefaultChainId, expiry);
}

public Task Respond<T, TR>(JsonRpcResponse<TR> response)
Expand All @@ -336,7 +336,7 @@ public Task Respond<T, TR>(JsonRpcResponse<TR> response)
public Task Emit<T>(EventData<T> eventData, string chainId = null)
{
return Emit<T>(Client.AddressProvider.DefaultSession.Topic, eventData,
chainId ?? Client.AddressProvider.DefaultChain);
chainId ?? Client.AddressProvider.DefaultChainId);
}

public Task Ping()
Expand Down Expand Up @@ -735,7 +735,7 @@ public async Task<TR> Request<T, TR>(string topic, T data, string chainId = null
var sessionData = Client.Session.Get(topic);
var defaultNamespace = Client.AddressProvider.DefaultNamespace ??
sessionData.Namespaces.Keys.FirstOrDefault();
defaultChainId = Client.AddressProvider.DefaultChain ??
defaultChainId = Client.AddressProvider.DefaultChainId ??
sessionData.Namespaces[defaultNamespace].Chains[0];
}
else
Expand Down
Loading

0 comments on commit 2d45af3

Please sign in to comment.