Skip to content

Commit 612d71b

Browse files
committed
Adding auth to Notion extension
1 parent 61d942a commit 612d71b

6 files changed

Lines changed: 500 additions & 0 deletions

File tree

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
using System;
2+
using System.Runtime.InteropServices;
3+
4+
namespace NotionExtension.Authentication;
5+
6+
public static class CredentialManager
7+
{
8+
#pragma warning disable CA1707 // Identifiers should not contain underscores
9+
public enum CRED_TYPE : int
10+
{
11+
GENERIC = 1,
12+
DOMAIN_PASSWORD = 2,
13+
DOMAIN_CERTIFICATE = 3,
14+
DOMAIN_VISIBLE_PASSWORD = 4,
15+
MAXIMUM = 5,
16+
}
17+
18+
public enum CRED_PERSIST : uint
19+
{
20+
Session = 1,
21+
LocalMachine = 2,
22+
Enterprise = 3,
23+
}
24+
25+
[StructLayout(LayoutKind.Sequential, CharSet = CharSet.Unicode)]
26+
public struct CREDENTIAL
27+
{
28+
public int Flags;
29+
public CRED_TYPE Type;
30+
[MarshalAs(UnmanagedType.LPWStr)]
31+
public string TargetName;
32+
[MarshalAs(UnmanagedType.LPWStr)]
33+
public string Comment;
34+
public System.Runtime.InteropServices.ComTypes.FILETIME LastWritten;
35+
public int CredentialBlobSize;
36+
public IntPtr CredentialBlob;
37+
public int Persist;
38+
public int AttributeCount;
39+
public IntPtr CredAttribute;
40+
[MarshalAs(UnmanagedType.LPWStr)]
41+
public string TargetAlias;
42+
[MarshalAs(UnmanagedType.LPWStr)]
43+
public string UserName;
44+
}
45+
46+
[DllImport("advapi32.dll", EntryPoint = "CredWriteW", CharSet = CharSet.Unicode, SetLastError = true)]
47+
[return: MarshalAs(UnmanagedType.Bool)]
48+
internal static extern bool CredWrite(CREDENTIAL credential, int flags);
49+
50+
[DllImport("advapi32.dll", EntryPoint = "CredDeleteW", CharSet = CharSet.Unicode)]
51+
[return: MarshalAs(UnmanagedType.Bool)]
52+
internal static extern bool CredDelete(string target, CRED_TYPE type, int flags);
53+
54+
[DllImport("advapi32.dll", EntryPoint = "CredReadW", CharSet = CharSet.Unicode, SetLastError = true)]
55+
[return: MarshalAs(UnmanagedType.Bool)]
56+
internal static extern bool CredRead(string target, CRED_TYPE type, int flags, out IntPtr credential);
57+
58+
[DllImport("advapi32.dll", EntryPoint = "CredEnumerateW", CharSet = CharSet.Unicode, SetLastError = true)]
59+
[return: MarshalAs(UnmanagedType.Bool)]
60+
internal static extern bool CredEnumerate(string filter, uint flags, out uint count, out IntPtr credentials);
61+
62+
[DllImport("advapi32.dll", EntryPoint = "CredFree", CharSet = CharSet.Unicode)]
63+
internal static extern void CredFree(IntPtr buffer);
64+
}
Lines changed: 203 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,203 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.ComponentModel;
4+
using System.Diagnostics;
5+
using System.Linq;
6+
using System.Runtime.InteropServices;
7+
using System.Security;
8+
using Windows.Security.Credentials;
9+
using static NotionExtension.Authentication.CredentialManager;
10+
11+
namespace NotionExtension.Authentication;
12+
public class CredentialVault(string applicationName = "") : ICredentialVault
13+
{
14+
private readonly string _credentialResourceName = string.IsNullOrEmpty(applicationName) ? CredentialVaultConfiguration.CredResourceName : applicationName;
15+
16+
private static class CredentialVaultConfiguration
17+
{
18+
public const string CredResourceName = "CmdPalNotionExtension";
19+
}
20+
21+
// Win32 Error codes
22+
public const int Win32ErrorNotFound = 1168;
23+
24+
private string AddCredentialResourceNamePrefix(string loginId) => _credentialResourceName + ": " + loginId;
25+
26+
public void SaveCredentials(string loginId, SecureString? accessToken)
27+
{
28+
// Initialize a credential object.
29+
var credential = new CREDENTIAL
30+
{
31+
Type = CRED_TYPE.GENERIC,
32+
TargetName = AddCredentialResourceNamePrefix(loginId),
33+
UserName = loginId,
34+
Persist = (int)CRED_PERSIST.LocalMachine,
35+
AttributeCount = 0,
36+
Flags = 0,
37+
Comment = string.Empty,
38+
};
39+
40+
try
41+
{
42+
if (accessToken != null)
43+
{
44+
credential.CredentialBlob = Marshal.SecureStringToCoTaskMemUnicode(accessToken);
45+
credential.CredentialBlobSize = accessToken.Length * 2;
46+
}
47+
else
48+
{
49+
Debug.WriteLine($"The access token is null for the loginId provided");
50+
throw new ArgumentNullException(nameof(accessToken));
51+
}
52+
53+
// Store credential under Windows Credentials inside Credential Manager.
54+
var isCredentialSaved = CredWrite(credential, 0);
55+
if (!isCredentialSaved)
56+
{
57+
Debug.WriteLine($"Writing credentials to Credential Manager has failed");
58+
throw new Win32Exception(Marshal.GetLastWin32Error());
59+
}
60+
}
61+
finally
62+
{
63+
if (credential.CredentialBlob != IntPtr.Zero)
64+
{
65+
Marshal.FreeCoTaskMem(credential.CredentialBlob);
66+
}
67+
}
68+
}
69+
70+
public PasswordCredential? GetCredentials(string loginId)
71+
{
72+
var credentialNameToRetrieve = AddCredentialResourceNamePrefix(loginId);
73+
var ptrToCredential = IntPtr.Zero;
74+
75+
try
76+
{
77+
var isCredentialRetrieved = CredRead(credentialNameToRetrieve, CRED_TYPE.GENERIC, 0, out ptrToCredential);
78+
if (!isCredentialRetrieved)
79+
{
80+
var error = Marshal.GetLastWin32Error();
81+
Debug.WriteLine($"Retrieving credentials from Credential Manager has failed for {loginId} with {error}");
82+
83+
// NotFound is expected and can be ignored.
84+
return error == Win32ErrorNotFound ? null : throw new Win32Exception(error);
85+
}
86+
87+
CREDENTIAL credentialObject;
88+
if (ptrToCredential != IntPtr.Zero)
89+
{
90+
#pragma warning disable CS8605 // Unboxing a possibly null value.
91+
credentialObject = Marshal.PtrToStructure<CREDENTIAL>(ptrToCredential);
92+
#pragma warning restore CS8605 // Unboxing a possibly null value.
93+
94+
}
95+
else
96+
{
97+
Debug.WriteLine($"No credentials found for this DeveloperId : {loginId}");
98+
return null;
99+
}
100+
101+
var accessTokenInChars = new char[credentialObject.CredentialBlobSize / 2];
102+
Marshal.Copy(credentialObject.CredentialBlob, accessTokenInChars, 0, accessTokenInChars.Length);
103+
104+
// convert accessTokenInChars to string
105+
string accessTokenString = new(accessTokenInChars);
106+
107+
for (var i = 0; i < accessTokenInChars.Length; i++)
108+
{
109+
// Zero out characters after they are copied over from an unmanaged to managed type.
110+
accessTokenInChars[i] = '\0';
111+
}
112+
113+
var credential = new PasswordCredential(_credentialResourceName, loginId, accessTokenString);
114+
return credential;
115+
}
116+
catch (Exception ex)
117+
{
118+
Debug.WriteLine(ex, $"Retrieving credentials from Credential Manager has failed unexpectedly: {loginId} : ");
119+
throw;
120+
}
121+
finally
122+
{
123+
if (ptrToCredential != IntPtr.Zero)
124+
{
125+
CredFree(ptrToCredential);
126+
}
127+
}
128+
}
129+
130+
public void RemoveCredentials(string loginId)
131+
{
132+
var targetCredentialToDelete = AddCredentialResourceNamePrefix(loginId);
133+
var isCredentialDeleted = CredDelete(targetCredentialToDelete, CRED_TYPE.GENERIC, 0);
134+
if (!isCredentialDeleted)
135+
{
136+
Debug.WriteLine($"Deleting credentials from Credential Manager has failed for {loginId}");
137+
}
138+
}
139+
140+
[System.Diagnostics.CodeAnalysis.SuppressMessage("Style", "IDE0301:Simplify collection initialization", Justification = "Leaving return type makes code clearer")]
141+
public IEnumerable<string> GetAllCredentials()
142+
{
143+
var ptrToCredential = IntPtr.Zero;
144+
145+
try
146+
{
147+
IntPtr[] allCredentials;
148+
uint count;
149+
150+
if (CredEnumerate(_credentialResourceName + "*", 0, out count, out ptrToCredential))
151+
{
152+
allCredentials = new IntPtr[count];
153+
Marshal.Copy(ptrToCredential, allCredentials, 0, (int)count);
154+
}
155+
else
156+
{
157+
var error = Marshal.GetLastWin32Error();
158+
159+
// NotFound is expected and can be ignored.
160+
return error == Win32ErrorNotFound ? Enumerable.Empty<string>() : throw new InvalidOperationException();
161+
}
162+
163+
if (count is 0)
164+
{
165+
return Enumerable.Empty<string>();
166+
}
167+
168+
var allLoginIds = new List<string>();
169+
for (var i = 0; i < allCredentials.Length; i++)
170+
{
171+
#pragma warning disable CS8605 // Unboxing a possibly null value.
172+
var credential = Marshal.PtrToStructure<CREDENTIAL>(allCredentials[i]);
173+
#pragma warning restore CS8605 // Unboxing a possibly null value.
174+
allLoginIds.Add(credential.UserName);
175+
}
176+
177+
return allLoginIds;
178+
}
179+
finally
180+
{
181+
if (ptrToCredential != IntPtr.Zero)
182+
{
183+
CredFree(ptrToCredential);
184+
}
185+
}
186+
}
187+
188+
public void RemoveAllCredentials()
189+
{
190+
var allCredentials = GetAllCredentials();
191+
foreach (var credential in allCredentials)
192+
{
193+
try
194+
{
195+
RemoveCredentials(credential);
196+
}
197+
catch (Exception ex)
198+
{
199+
Debug.WriteLine(ex, $"Deleting credentials from Credential Manager has failed unexpectedly: {credential} : ");
200+
}
201+
}
202+
}
203+
}
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
using System.Collections.Generic;
2+
using System.Security;
3+
using Windows.Security.Credentials;
4+
5+
namespace NotionExtension.Authentication;
6+
7+
public interface ICredentialVault
8+
{
9+
PasswordCredential? GetCredentials(string loginId);
10+
11+
void RemoveCredentials(string loginId);
12+
13+
void SaveCredentials(string loginId, SecureString? accessToken);
14+
15+
IEnumerable<string> GetAllCredentials();
16+
17+
void RemoveAllCredentials();
18+
}

0 commit comments

Comments
 (0)