diff --git a/build.sh b/build.sh old mode 100755 new mode 100644 diff --git a/resource/checker/checker-config.json b/resource/checker/checker-config.json index fd77c86d..5e28f533 100644 --- a/resource/checker/checker-config.json +++ b/resource/checker/checker-config.json @@ -146,6 +146,11 @@ "checkerPath": "checker/taint/go/restful-entrypoint-collect-checker.ts", "description": "go-restful entrypoint采集以及框架source添加" }, + { + "checkerId": "echo-entrypoint-collect-checker", + "checkerPath": "checker/taint/go/echo-entrypoint-collect-checker.ts", + "description": "echo entrypoint采集以及框架source添加" + }, { "checkerId": "get_file_ast", "checkerPath": "checker/sdk/get-file-ast-checker.ts", diff --git a/resource/checker/checker-pack-config.json b/resource/checker/checker-pack-config.json index d35a77e5..2380e197 100644 --- a/resource/checker/checker-pack-config.json +++ b/resource/checker/checker-pack-config.json @@ -6,6 +6,7 @@ "taint_flow_go_input", "cobra.Command-builtIn", "go-restful-entryPoints-collect-checker", + "echo-entrypoint-collect-checker", "gorilla-mux-entrypoint-collect-checker", "gRpc-entryPoint-collect-checker", "go-main-entryPoints-collection", diff --git a/resource/example-rule-config/rule_config_go.json b/resource/example-rule-config/rule_config_go.json index 2770fc3d..71dde301 100644 --- a/resource/example-rule-config/rule_config_go.json +++ b/resource/example-rule-config/rule_config_go.json @@ -468,6 +468,15 @@ } ], "FuncCallArgTaintSource": [ + { + "args": [ + "0" + ], + "calleeType": "echo.Context", + "fsig": "Bind", + "scopeFile": "all", + "scopeFunc": "all" + }, { "args": [ "0" diff --git a/src/checker/taint/go/echo-entrypoint-collect-checker.ts b/src/checker/taint/go/echo-entrypoint-collect-checker.ts new file mode 100644 index 00000000..042bd93c --- /dev/null +++ b/src/checker/taint/go/echo-entrypoint-collect-checker.ts @@ -0,0 +1,439 @@ +import type Unit from '../../../engine/analyzer/common/value/unit' +import { flattenUnionValues, processEntryPointAndTaintSource } from './util' + +const config = require('../../../config') +const GoAnalyzer = require('../../../engine/analyzer/golang/common/go-analyzer') + +const KnownPackageName = { + 'github.com/labstack/echo/v4': 'echo', + 'github.com/labstack/echo-jwt/v4': 'echojwt', +} + +const RouteRegistryObject = ['github.com/labstack/echo/v4.New()'] + +const MiddlewareHandlerRegistryObject = [ + 'github.com/labstack/echo/v4/middleware', + 'github.com/labstack/echo-contrib/casbin', + 'github.com/labstack/echo-jwt/v4', + 'github.com/labstack/echo-contrib/echoprometheus', + 'github.com/labstack/echo-contrib/session', +] + +const ConfigObjectCollectionTable = new Map>([ + [ + 'BasicAuthWithConfig', + [ + { name: 'Skipper', source: '0' }, + { name: 'Validator', source: '0, 1, 2' }, + ], + ], + [ + 'BodyDumpWithConfig', + [ + { name: 'Skipper', source: '0' }, + { name: 'Handler', source: '0, 1, 2' }, + ], + ], + ['BodyLimitWithConfig', [{ name: 'Skipper', source: '0' }]], + [ + 'MiddlewareWithConfig', + [ + { name: 'Skipper', source: '0' }, + { name: 'EnforceHandler', source: '0, 1' }, + { name: 'UserGetter', source: '0' }, + { name: 'ErrorHandler', source: '0, 1' }, + ], + ], + [ + 'ContextTimeoutWithConfig', + [ + { name: 'Skipper', source: '0' }, + { name: 'ErrorHandler', source: '0, 1' }, + ], + ], + [ + 'CORSWithConfig', + [ + { name: 'Skipper', source: '0' }, + { name: 'AllowOriginFunc', source: '0' }, + ], + ], + [ + 'CSRFWithConfig', + [ + { name: 'Skipper', source: '0' }, + { name: 'ErrorHandler', source: '0, 1' }, + ], + ], + ['DecompressWithConfig', [{ name: 'Skipper', source: '0' }]], + ['GzipWithConfig', [{ name: 'Skipper', source: '0' }]], + [ + 'WithConfig', + [ + { name: 'Skipper', source: '0' }, + { name: 'BeforeFunc', source: '0' }, + { name: 'SuccessHandler', source: '0' }, + { name: 'ErrorHandler', source: '0, 1' }, + { name: 'KeyFunc', source: '0' }, + { name: 'ParseTokenFunc', source: '0' }, + { name: 'NewClaimsFunc', source: '0' }, + ], + ], + [ + 'KeyAuthWithConfig', + [ + { name: 'Skipper', source: '0' }, + { name: 'Validator', source: '0, 1' }, + { name: 'ErrorHandler', source: '0, 1' }, + ], + ], + [ + 'LoggerWithConfig', + [ + { name: 'Skipper', source: '0' }, + { name: 'CustomTagFunc', source: '0' }, + ], + ], + [ + 'RequestLoggerWithConfig', + [ + { name: 'Skipper', source: '0' }, + { name: 'BeforeNextFunc', source: '0' }, + { name: 'LogValuesFunc', source: '0, 1' }, + ], + ], + [ + 'MethodOverrideWithConfig', + [ + { name: 'Skipper', source: '0' }, + { name: 'Getter', source: '0' }, + ], + ], + [ + 'NewMiddlewareWithConfig', + [ + { name: 'Skipper', source: '0' }, + { name: 'BeforeNext', source: '0' }, + { name: 'AfterNext', source: '0, 1' }, + { name: 'StatusCodeResolver', source: '0, 1' }, + ], + ], + ['Proxy', [{ name: 'Next', source: '0' }]], + [ + 'ProxyWithConfig', + [ + { name: 'Skipper', source: '0' }, + { name: 'RetryFilter', source: '0, 1' }, + { name: 'ErrorHandler', source: '0, 1' }, + { name: 'ModifyResponse', source: '0' }, + ], + ], + ['RateLimiter', [{ name: 'Allow', source: '0' }]], + [ + 'RateLimiterWithConfig', + [ + { name: 'Skipper', source: '0' }, + { name: 'BeforeFunc', source: '0' }, + { name: 'IdentifierExtractor', source: '0' }, + { name: 'ErrorHandler', source: '0, 1' }, + { name: 'DenyHandler', source: '0, 1, 2' }, + ], + ], + [ + 'RecoverWithConfig', + [ + { name: 'Skipper', source: '0' }, + { name: 'LogErrorFunc', source: '0, 1' }, + ], + ], + ['HTTPSRedirectWithConfig', [{ name: 'Skipper', source: '0' }]], + [ + 'RequestIDWithConfig', + [ + { name: 'Skipper', source: '0' }, + { name: 'RequestIDHandler', source: '0, 1' }, + ], + ], + ['RewriteWithConfig', [{ name: 'Skipper', source: '0' }]], + ['SecureWithConfig', [{ name: 'Skipper', source: '0' }]], + [ + 'Middleware', + [ + { name: 'Get', source: '0' }, + { name: 'New', source: '0' }, + { name: 'Save', source: '0' }, + ], + ], + ['StaticWithConfig', [{ name: 'Skipper', source: '0' }]], + ['AddTrailingSlashWithConfig', [{ name: 'Skipper', source: '0' }]], + ['RemoveTrailingSlashWithConfig', [{ name: 'Skipper', source: '0' }]], +]) + +const Checker = require('../../common/checker') + +const processedRouteRegistry = new Set() + +/** + * + */ +class EchoEntrypointCollectChecker extends Checker { + /** + * + * @param resultManager + */ + constructor(resultManager: any) { + super(resultManager, 'echo-entrypoint-collect-checker') + GoAnalyzer.registerKnownPackageNames(KnownPackageName) + } + + /** + * + * @param analyzer + * @param scope + * @param node + * @param state + * @param info + */ + triggerAtFunctionCallBefore(analyzer: any, scope: any, node: any, state: any, info: any) { + const { fclos, argvalues } = info + if (config.entryPointMode === 'ONLY_CUSTOM') return + if (!(fclos && fclos.object && fclos.property)) return + const { object, property } = fclos + if (!object._qid || !property.name) return + if (!RouteRegistryObject.some((obj) => object._qid.includes(obj))) return + switch (property.name) { + case 'Use': + case 'Pre': + this.handleMiddlewareArgs(analyzer, scope, state, argvalues) + break + case 'CONNECT': + case 'DELETE': + case 'GET': + case 'HEAD': + case 'OPTIONS': + case 'PATCH': + case 'POST': + case 'PUT': + case 'TRACE': + case 'RouteNotFound': + case 'Any': + processEntryPointAndTaintSource(analyzer, state, processedRouteRegistry, argvalues[1], '0') + this.handleMiddlewareArgs(analyzer, scope, state, argvalues.slice(2)) + break + case 'Match': + case 'Add': + processEntryPointAndTaintSource(analyzer, state, processedRouteRegistry, argvalues[2], '0') + this.handleMiddlewareArgs(analyzer, scope, state, argvalues.slice(3)) + break + case 'File': + this.handleMiddlewareArgs(analyzer, scope, state, argvalues.slice(2)) + break + case 'FileFS': + flattenUnionValues([argvalues[2]]).forEach((fs) => { + processEntryPointAndTaintSource(analyzer, state, processedRouteRegistry, fs.field.Open, '0') + }) + this.handleMiddlewareArgs(analyzer, scope, state, argvalues.slice(3)) + break + case 'Host': + case 'Group': + this.handleMiddlewareArgs(analyzer, scope, state, argvalues.slice(1)) + break + default: + break + } + } + + /** + * + * @param analyzer + * @param scope + * @param node + * @param state + * @param info + */ + triggerAtSymbolInterpretOfEntryPointAfter(analyzer: any, scope: any, node: any, state: any, info: any) { + if (info?.entryPoint.functionName === 'main') processedRouteRegistry.clear() + } + + /** + * + * @param analyzer + * @param scope + * @param node + * @param state + * @param info + */ + triggerAtAssignment(analyzer: any, scope: any, node: any, state: any, info: any) { + if (config.entryPointMode === 'ONLY_CUSTOM') return + const { lvalue, rvalue } = info + if (!(lvalue.object && lvalue.property)) return + const { object, property } = lvalue + if (!object._qid || !property.name) return + if (!RouteRegistryObject.some((obj) => object._qid.includes(obj))) return + const rvalueObjs = flattenUnionValues([rvalue]) + switch (property.name) { + case 'HTTPErrorHandler': + rvalueObjs.forEach((obj) => + processEntryPointAndTaintSource(analyzer, state, processedRouteRegistry, obj, '0, 1') + ) + break + case 'Binder': + rvalueObjs.forEach((obj) => + processEntryPointAndTaintSource(analyzer, state, processedRouteRegistry, obj.field.Bind, '1') + ) + break + case 'Renderer': + rvalueObjs.forEach((obj) => + processEntryPointAndTaintSource(analyzer, state, processedRouteRegistry, obj.field.Render, '3') + ) + break + case 'Filesystem': + rvalueObjs.forEach((obj) => + processEntryPointAndTaintSource(analyzer, state, processedRouteRegistry, obj.field.Open, '0') + ) + break + default: + break + } + } + + /** + * + * @param analyzer + * @param state + * @param symbol + */ + handleConfigObjectCollection(analyzer: any, state: any, symbol: any) { + const rules = ConfigObjectCollectionTable.get(symbol.expression?.name) + if (!rules) return + flattenUnionValues([symbol.arguments[0]]).forEach((middlewareConfig) => { + rules.forEach((rule) => { + const fieldValue = middlewareConfig.field[rule.name] + if (!fieldValue) return + processEntryPointAndTaintSource(analyzer, state, processedRouteRegistry, fieldValue, rule.source) + }) + }) + } + + /** + * + * @param analyzer + * @param state + * @param symbol + */ + handleKnownEchoMiddlewares(analyzer: any, state: any, symbol: any) { + if (symbol.type !== 'CallExpression') return + const objectQid = symbol.expression?._qid + if (!(objectQid && MiddlewareHandlerRegistryObject.some((obj) => objectQid.startsWith(obj)))) return + + switch (symbol.expression.name) { + case 'BasicAuth': + processEntryPointAndTaintSource(analyzer, state, processedRouteRegistry, symbol.arguments[0], '0, 1, 2') + break + case 'BodyDump': + processEntryPointAndTaintSource(analyzer, state, processedRouteRegistry, symbol.arguments[0], '0, 1, 2') + break + case 'KeyAuth': + processEntryPointAndTaintSource(analyzer, state, processedRouteRegistry, symbol.arguments[0], '0, 1') + break + case 'WithConfig': + flattenUnionValues([symbol.arguments[0]]).forEach((middlewareConfig) => { + const tokenLookupFuncs = middlewareConfig.field.TokenLookupFuncs + if (!tokenLookupFuncs) return + Object.values(tokenLookupFuncs.value).forEach((v) => { + processEntryPointAndTaintSource(analyzer, state, processedRouteRegistry, v as Unit, '0') + }) + }) + this.handleConfigObjectCollection(analyzer, state, symbol) + break + case 'NewMiddlewareWithConfig': + flattenUnionValues([symbol.arguments[0]]).forEach((middlewareConfig) => { + const labelFuncs = middlewareConfig.field.LabelFuncs + if (!labelFuncs) return + Object.values(labelFuncs.value).forEach((v) => { + processEntryPointAndTaintSource(analyzer, state, processedRouteRegistry, v as Unit, '0, 1') + }) + }) + this.handleConfigObjectCollection(analyzer, state, symbol) + break + case 'ProxyWithConfig': + flattenUnionValues([symbol.arguments[0]]).forEach((middlewareConfig) => { + const balancerNext = middlewareConfig.field?.Balancer?.field?.Next + if (balancerNext) { + processEntryPointAndTaintSource(analyzer, state, processedRouteRegistry, balancerNext, '0') + } + const transportRoundTrip = middlewareConfig.field?.Transport?.field?.RoundTrip + if (transportRoundTrip) { + processEntryPointAndTaintSource(analyzer, state, processedRouteRegistry, transportRoundTrip, '0') + } + }) + this.handleConfigObjectCollection(analyzer, state, symbol) + break + case 'RateLimiterWithConfig': + flattenUnionValues([symbol.arguments[0]]).forEach((middlewareConfig) => { + const allow = middlewareConfig.field?.Store?.field?.Allow + if (allow) { + processEntryPointAndTaintSource(analyzer, state, processedRouteRegistry, allow, '0') + } + }) + this.handleConfigObjectCollection(analyzer, state, symbol) + break + case 'MiddlewareWithConfig': + flattenUnionValues([symbol.arguments[0]]).forEach((middlewareConfig) => { + const store = middlewareConfig.field.Store + if (!store) return + ;[store.field.Get, store.field.New, store.field.Save] + .filter((v) => v) + .forEach((v) => { + processEntryPointAndTaintSource(analyzer, state, processedRouteRegistry, v, '0') + }) + }) + this.handleConfigObjectCollection(analyzer, state, symbol) + break + case 'StaticWithConfig': + flattenUnionValues([symbol.arguments[0]]).forEach((middlewareConfig) => { + const open = middlewareConfig.field?.Filesystem?.field?.Open + if (open) { + processEntryPointAndTaintSource(analyzer, state, processedRouteRegistry, open, '0') + } + }) + this.handleConfigObjectCollection(analyzer, state, symbol) + break + default: + this.handleConfigObjectCollection(analyzer, state, symbol) + break + } + } + + /** + * + * @param analyzer + * @param scope + * @param state + * @param middlewareFunctionValue + */ + handleCustomMiddleware(analyzer: any, scope: any, state: any, middlewareFunctionValue: any) { + const retVal = analyzer.processAndCallFuncDef(scope, middlewareFunctionValue.fdef, middlewareFunctionValue, state) + processEntryPointAndTaintSource(analyzer, state, processedRouteRegistry, retVal, '0') + } + + /** + * + * @param analyzer + * @param scope + * @param state + * @param list + */ + handleMiddlewareArgs(analyzer: any, scope: any, state: any, list: Array) { + const flattened = flattenUnionValues(list) + flattened.forEach((unit) => { + if (unit.vtype === 'symbol') { + this.handleKnownEchoMiddlewares(analyzer, state, unit) + } else if (unit.vtype === 'fclos') { + this.handleCustomMiddleware(analyzer, scope, state, unit) + } + }) + } +} + +module.exports = EchoEntrypointCollectChecker diff --git a/src/checker/taint/go/main-entrypoint-collect-checker.ts b/src/checker/taint/go/main-entrypoint-collect-checker.ts index 35f834f2..fc2e8254 100644 --- a/src/checker/taint/go/main-entrypoint-collect-checker.ts +++ b/src/checker/taint/go/main-entrypoint-collect-checker.ts @@ -2,7 +2,7 @@ import type { EntryPoint } from '../../../engine/analyzer/common/entrypoint' const _ = require('lodash') const GoEntryPoint = require('../../../engine/analyzer/golang/common/entrypoint-collector/go-default-entrypoint') -const { completeEntryPoint } = require('../common-kit/entry-points-util') +const completeEntryPoint = require('../common-kit/entry-points-util') const Config = require('../../../config') const Checker = require('../../common/checker') diff --git a/src/checker/taint/go/restful-entrypoint-collect-checker.ts b/src/checker/taint/go/restful-entrypoint-collect-checker.ts index 8cd40af8..6fca9f0c 100644 --- a/src/checker/taint/go/restful-entrypoint-collect-checker.ts +++ b/src/checker/taint/go/restful-entrypoint-collect-checker.ts @@ -1,15 +1,20 @@ +import { processEntryPointAndTaintSource } from './util' + const config = require('../../../config') +const GoAnalyzer = require('../../../engine/analyzer/golang/common/go-analyzer') const RouteRegistryProperty = ['Filter', 'To', 'If'] +const KnownPackageName = { + 'github.com/emicklei/go-restful': 'restful', + 'github.com/emicklei/go-restful/v3': 'restful', +} const RouteRegistryObject = [ - 'github.com/emicklei/go-restful/v3.WebService', 'github.com/emicklei/go-restful.WebService', + 'github.com/emicklei/go-restful/v3.WebService', ] -const IntroduceTaint = require('../common-kit/source-util') const Checker = require('../../common/checker') -const completeEntryPoint = require('../common-kit/entry-points-util') -const processedRouteRegistry = new Set() +const processedRouteRegistry = new Set() /** * @@ -21,6 +26,7 @@ class RestfulEntrypointCollectChecker extends Checker { */ constructor(resultManager: any) { super(resultManager, 'go-restful-entryPoints-collect-checker') + GoAnalyzer.registerKnownPackageNames(KnownPackageName) } /** @@ -67,20 +73,10 @@ class RestfulEntrypointCollectChecker extends Checker { const propertyName = property.name if ( RouteRegistryObject.some((prefix) => objectQid.startsWith(prefix)) && - RouteRegistryProperty.includes(propertyName) + RouteRegistryProperty.includes(propertyName) && + argValues[0] ) { - if (argValues.length < 1) return - const arg0 = argValues[0] - - if (arg0?.vtype === 'fclos' && arg0?.ast.loc) { - const hash = JSON.stringify(arg0.ast.loc) - if (!processedRouteRegistry.has(hash)) { - processedRouteRegistry.add(hash) - IntroduceTaint.introduceFuncArgTaintBySelfCollection(arg0, state, analyzer, '0', 'GO_INPUT') - const entryPoint = completeEntryPoint(arg0) - analyzer.entryPoints.push(entryPoint) - } - } + processEntryPointAndTaintSource(analyzer, state, processedRouteRegistry, argValues[0], '0') } } } diff --git a/src/checker/taint/go/util.ts b/src/checker/taint/go/util.ts new file mode 100644 index 00000000..13c27790 --- /dev/null +++ b/src/checker/taint/go/util.ts @@ -0,0 +1,53 @@ +import type Unit from '../../../engine/analyzer/common/value/unit' + +const IntroduceTaint = require('../common-kit/source-util') +const completeEntryPoint = require('../common-kit/entry-points-util') + +/** + * + * @param list + */ +export function flattenUnionValues(list: Array): Array { + return list.flatMap((unit) => { + switch (unit.vtype) { + case 'union': + return flattenUnionValues(unit.value) + case 'fclos': + case 'symbol': + case 'object': + return [unit] + default: + throw new Error(`flattenUnionValues: Unknown type ${unit.vtype}`) + } + }) +} + +/** + * + * @param analyzer + * @param state + * @param processedRouteRegistry + * @param entryPointUnitValue + * @param source + */ +export function processEntryPointAndTaintSource( + analyzer: any, + state: any, + processedRouteRegistry: Set, + entryPointUnitValue: Unit, + source: string +) { + flattenUnionValues([entryPointUnitValue]) + .filter((val) => val.vtype === 'fclos') + .forEach((entryPointFuncValue) => { + if (entryPointFuncValue?.ast.loc) { + const hash = JSON.stringify(entryPointFuncValue.ast.loc) + if (!processedRouteRegistry.has(hash)) { + processedRouteRegistry.add(hash) + IntroduceTaint.introduceFuncArgTaintBySelfCollection(entryPointFuncValue, state, analyzer, source, 'GO_INPUT') + const entryPoint = completeEntryPoint(entryPointFuncValue) + analyzer.entryPoints.push(entryPoint) + } + } + }) +} diff --git a/src/engine/analyzer/common/analyzer.ts b/src/engine/analyzer/common/analyzer.ts index 12757037..924607dd 100644 --- a/src/engine/analyzer/common/analyzer.ts +++ b/src/engine/analyzer/common/analyzer.ts @@ -1,5 +1,3 @@ -import { floor } from 'lodash' - const _ = require('lodash') const Uuid = require('node-uuid') const chalk = require('chalk') @@ -1486,8 +1484,9 @@ class Analyzer extends MemSpace { * @param scope * @param node * @param state + * @param cachedFclos */ - processCallExpression(scope: any, node: any, state: any) { + processCallExpression(scope: any, node: any, state: any, cachedFclos?: any) { /* { callee, arguments, } @@ -1498,7 +1497,7 @@ class Analyzer extends MemSpace { einfo: state.einfo, }) - const fclos = this.processInstruction(scope, node.callee, state) + const fclos = cachedFclos ?? this.processInstruction(scope, node.callee, state) if (!fclos) return UndefinedValue() if (node?.callee?.type === 'MemberAccess' && fclos.fdef && node.callee?.object?.type !== 'SuperExpression') { fclos._this = this.processInstruction(scope, node.callee.object, state) diff --git a/src/engine/analyzer/golang/common/go-analyzer.ts b/src/engine/analyzer/golang/common/go-analyzer.ts index 8ca4dfb9..52420626 100644 --- a/src/engine/analyzer/golang/common/go-analyzer.ts +++ b/src/engine/analyzer/golang/common/go-analyzer.ts @@ -348,7 +348,7 @@ class GoAnalyzer extends Analyzer { ainfo: this.ainfo, }) } - ret = super.processCallExpression(scope, node, state) + ret = super.processCallExpression(scope, node, state, fclos) if (ret && this.checkerManager) { this.checkerManager.checkAtFunctionCallAfter(this, scope, node, state, { fclos, @@ -517,6 +517,16 @@ class GoAnalyzer extends Analyzer { } } + private static knownPackageName: Record = {} + + /** + * Register known module name to package name mapping for default import variable name fix + * @param knownPackageName A map from module name to known package name + */ + static registerKnownPackageNames(knownPackageName: Record) { + GoAnalyzer.knownPackageName = { ...GoAnalyzer.knownPackageName, ...knownPackageName } + } + /** * * @param scope @@ -594,14 +604,15 @@ class GoAnalyzer extends Analyzer { } } } else { - // 如果是import,则定义真正的包名而非目录名 - if ( - initialNode?.type === 'ImportExpression' && - initVal?.vtype === 'package' && - initVal.name && - id.name === initialNode.from?.value?.split('/').at(-1) - ) { - id.name = initVal.name + if (initialNode?.type === 'ImportExpression') { + // 处理 default import 情况 + if (node._meta?.isDefaultImport === true && GoAnalyzer.knownPackageName[initialNode.from?.value]) { + id.name = GoAnalyzer.knownPackageName[initialNode.from?.value] + } + // 如果是import,则定义真正的包名而非目录名 + if (initVal?.vtype === 'package' && initVal.name && id.name === initialNode.from?.value?.split('/').at(-1)) { + id.name = initVal.name + } } this.saveVarInCurrentScope(scope, id, initVal, state) }