diff --git a/src/index.ts b/src/index.ts index 7c8ab44..afddd92 100644 --- a/src/index.ts +++ b/src/index.ts @@ -42,6 +42,9 @@ export class TokenGate { urlFromReq: (req: Request) => string; tzAddrFromReq: (req: Request) => string | undefined; + applyAddressWhitelist: boolean; + maxWhitelistedClaims: number; + constructor({ dbPool }: { dbPool: DbPool }) { this.db = dbPool; @@ -59,6 +62,9 @@ export class TokenGate { this.urlFromReq = (req) => req.baseUrl; this.tzAddrFromReq = (req: any) => req.auth?.userAddress; + + this.applyAddressWhitelist = false; + this.maxWhitelistedClaims = 0; } loadSpecFromFile(filepath: string, overwrite: boolean = true): this { @@ -172,6 +178,12 @@ export class TokenGate { return this; } + enableAddressWhitelist(maxClaims: number = 1): this { + this.applyAddressWhitelist = true; + this.maxWhitelistedClaims = maxClaims; + return this; + } + getSpec(): TokenGateSpec { return Object.keys(this.rules).reduce((res, endpoint) => { res[endpoint] = { @@ -194,6 +206,7 @@ export class TokenGate { use(req: Request, resp: Response, next: NextFunction): void { const url = this.urlFromReq(req); const tzAddr = this.tzAddrFromReq(req); + this.hasAccess(url, tzAddr) .then((access) => { if (!access) { @@ -208,6 +221,28 @@ export class TokenGate { }); } + async isAddressInWhitelist( + userAddress?: string + ): Promise<"allowed" | "claimed" | "forbidden"> { + if (typeof userAddress === "undefined") { + return "forbidden"; + } + const qryResp = await this.db.query( + ` +SELECT claimed +FROM whitelisted_wallet_addresses +WHERE address = $1 + `, + [userAddress] + ); + if (qryResp.rowCount === 0) { + return "forbidden"; + } + return Number(qryResp.rows[0].claimed) >= this.maxWhitelistedClaims + ? "claimed" + : "allowed"; + } + async hasAccess( endpoint: Endpoint, tzAddr: string | undefined @@ -221,6 +256,12 @@ export class TokenGate { if (typeof tzAddr === "undefined") { return false; } + if ( + this.applyAddressWhitelist && + (await this.isAddressInWhitelist(tzAddr)) !== "allowed" + ) { + return false; + } return await this.#ownsOneOf(tzAddr, rule.allowedTokens); }