summaryrefslogtreecommitdiff
path: root/vendor/github.com/hashicorp/hil/check_identifier.go
blob: 474f50588e1759459d99c47f212473ac7ff8bc87 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
package hil

import (
	"fmt"
	"sync"

	"github.com/hashicorp/hil/ast"
)

// IdentifierCheck is a SemanticCheck that checks that all identifiers
// resolve properly and that the right number of arguments are passed
// to functions.
type IdentifierCheck struct {
	// Scope resolves function names (LookupFunc) and variable names
	// (LookupVar) encountered while walking the AST.
	Scope ast.Scope

	// err records the first problem found during the current Visit and is
	// cleared before Visit returns; lock serializes Visit calls because err
	// is per-walk state stored on the checker itself.
	err  error
	lock sync.Mutex
}

// Visit walks the AST rooted at root, passing the checker's per-node
// callback to root.Accept, and returns the first unresolved-identifier or
// argument-count error it recorded, or nil if none was found. Calls are
// serialized by an internal mutex, and the recorded error is cleared
// before Visit returns so the checker can be reused.
func (c *IdentifierCheck) Visit(root ast.Node) error {
	c.lock.Lock()
	defer func() {
		// Clear per-walk state first, then release the lock, matching the
		// LIFO order of two separate defers.
		c.reset()
		c.lock.Unlock()
	}()

	root.Accept(c.visit)

	// The return value is evaluated before the deferred reset runs, so the
	// caller still sees the error recorded during this walk.
	return c.err
}

// visit is the per-node callback handed to Accept. Function calls and
// variable accesses are checked against the scope. Every other node kind
// is ignored. Once an error has been recorded, no further checks are made,
// so only the first problem is reported. The node is always returned
// unchanged because this visitor never rewrites the tree.
func (c *IdentifierCheck) visit(raw ast.Node) ast.Node {
	if c.err == nil {
		switch node := raw.(type) {
		case *ast.Call:
			c.visitCall(node)
		case *ast.VariableAccess:
			c.visitVariableAccess(node)
		}
	}

	return raw
}

// visitCall verifies that the called function exists in the scope and that
// the call supplies an acceptable number of arguments. A non-variadic
// function needs exactly len(ArgTypes) arguments. A variadic function
// needs at least that many. Any violation is recorded via createErr.
func (c *IdentifierCheck) visitCall(n *ast.Call) {
	fn, found := c.Scope.LookupFunc(n.Func)
	if !found {
		c.createErr(n, fmt.Sprintf("unknown function called: %s", n.Func))
		return
	}

	want, got := len(fn.ArgTypes), len(n.Args)

	// Extra trailing arguments are only allowed for variadic functions;
	// the required prefix must always be fully present.
	argsOK := got == want || (fn.Variadic && got > want)
	if !argsOK {
		c.createErr(n, fmt.Sprintf(
			"%s: expected %d arguments, got %d",
			n.Func, want, got))
	}
}

// visitVariableAccess records an error when the accessed variable name
// cannot be resolved through the checker's scope.
func (c *IdentifierCheck) visitVariableAccess(n *ast.VariableAccess) {
	_, found := c.Scope.LookupVar(n.Name)
	if found {
		return
	}

	c.createErr(n, fmt.Sprintf(
		"unknown variable accessed: %s", n.Name))
}

// createErr stores str as the checker's error, prefixed with the source
// position of node n (formatted with %s) so the offending expression
// can be located.
func (c *IdentifierCheck) createErr(n ast.Node, str string) {
	where := n.Pos()
	c.err = fmt.Errorf("%s: %s", where, str)
}

// reset clears any error recorded during the previous walk. Visit defers
// it so that each walk starts with no error set.
func (c *IdentifierCheck) reset() {
	c.err = nil
}