chibiccを読む~Cコンパイラコードリーディング~ ステップ14

トップページ
jupiteroak.hatenablog.com
「低レイヤを知りたい人のためのCコンパイラ作成入門」のCコンパイラを読んでいきます。
www.sigbus.info
ステップ14に該当
github.com
ステップ14に該当
github.com
ステップ14に該当
github.com

今回作成するコンパイラ

関数の呼び出しに対応する(引数なしの関数を呼び出せる、最大引数6つの関数を呼び出せる、関数呼び出しの前にRSPが16の倍数になっている、コンパイラを作成する)

追加・修正されたコンパイラソースコード(引数なしの関数を呼び出せるコンパイラを作成する)

NodeKind

https://github.com/rui314/chibicc/commit/f5540b578e4bafa57d7ff8d94f4a0a46c95ede12#diff-d06dbb7ef5899cdf50b340464444680b13aded45363e7aba944dc3551fdf6334R72
https://github.com/rui314/chibicc/blob/f5540b578e4bafa57d7ff8d94f4a0a46c95ede12/chibicc.h#L72

// AST node
typedef enum {
  ND_ADD,       // +
  ND_SUB,       // -
  ND_MUL,       // *
  ND_DIV,       // /
  ND_EQ,        // ==
  ND_NE,        // !=
  ND_LT,        // <
  ND_LE,        // <=
  ND_ASSIGN,    // =
  ND_RETURN,    // "return"
  ND_IF,        // "if"
  ND_WHILE,     // "while"
  ND_FOR,       // "for"
  ND_BLOCK,     // { ... }
  ND_FUNCALL,   // Function call
  ND_EXPR_STMT, // Expression statement
  ND_VAR,       // Variable
  ND_NUM,       // Integer
} NodeKind;

関数呼び出しを表現するノード型 ND_FUNCALLを追加します。

Node構造体

https://github.com/rui314/chibicc/commit/f5540b578e4bafa57d7ff8d94f4a0a46c95ede12#diff-d06dbb7ef5899cdf50b340464444680b13aded45363e7aba944dc3551fdf6334R98
https://github.com/rui314/chibicc/blob/f5540b578e4bafa57d7ff8d94f4a0a46c95ede12/chibicc.h#L98

// AST node type
typedef struct Node Node;
struct Node {
  NodeKind kind; // Node kind
  Node *next;    // Next node

  Node *lhs;     // Left-hand side
  Node *rhs;     // Right-hand side

  // "if, "while" or "for" statement
  Node *cond;
  Node *then;
  Node *els;
  Node *init;
  Node *inc;

  // Block
  Node *body;

  // Function call
  char *funcname;

  Var *var;      // Used if kind == ND_VAR
  int val;       // Used if kind == ND_NUM
};

関数名を保存するメンバfuncnameを追加します。

primary関数

https://github.com/rui314/chibicc/commit/f5540b578e4bafa57d7ff8d94f4a0a46c95ede12#diff-a07721cd062be25900bddb926de15fc103cf32ea2726d1fea286f6548b810c6aR251
https://github.com/rui314/chibicc/blob/f5540b578e4bafa57d7ff8d94f4a0a46c95ede12/parse.c#L251

Node *primary() {
  if (consume("(")) {
    Node *node = expr();
    expect(")");
    return node;
  }

  Token *tok = consume_ident();
  if (tok) {
    if (consume("(")) {
      expect(")");
      Node *node = new_node(ND_FUNCALL);
      node->funcname = strndup(tok->str, tok->len);
      return node;
    }

    Var *var = find_var(tok);
    if (!var)
      var = push_var(strndup(tok->str, tok->len));
    return new_var(var);
  }

  return new_num(expect_number());
}

primary関数は、生成規則 primary = "(" expr ")" | ident args? | num に基づいて、抽象構文木のノードを生成します。
(argsの生成規則は args = "(" ")" です。)

"("、expr、")"(変更なし)
  if (consume("(")) {
    Node *node = expr();
    expect(")");
    return node;
  }
ident、argsを0回か1回
    Token *tok = consume_ident();
  if (tok) {
    if (consume("(")) {
      expect(")");
      Node *node = new_node(ND_FUNCALL);
      node->funcname = strndup(tok->str, tok->len);
      return node;
    }

    Var *var = find_var(tok);
    if (!var)
      var = push_var(strndup(tok->str, tok->len));
    return new_var(var);
  }

consume_ident関数を呼び出してtokが真となる場合 → 識別子(ローカル変数、または、関数名)を表現するトークンを取得できた場合の処理です。
consume("(")の戻り値がtrueとなる場合 → 識別子の次のトークンが"("の場合 → 識別子が関数名の場合は、expect(")")を呼び出してトークン列を1つ読み進めます。
new_node関数を呼び出して、関数呼び出しを表すノードを生成し、生成したノードに関数名を記録します。

識別子がローカル変数の場合は、find_var関数を呼び出して、Var構造体からなる連結リストからtokが持つ文字列と同じ名前を持つvar構造体変数を取得します。
!varが真となる場合→tokが持つ文字列と同じ名前を持つvar構造体変数を取得できなかった場合は、push_var関数を呼び出して、var構造体を新規に生成しvar構造体の連結リストに追加します。
最後に、new_var関数を呼び出して、ローカル変数に対応するノードを新規に生成します。

num(変更なし)
  return new_num(expect_number());

gen関数

https://github.com/rui314/chibicc/commit/f5540b578e4bafa57d7ff8d94f4a0a46c95ede12#diff-629fe11334ae1d560032cdb6cc6f9a4fbb0f5b1365894b6b648d6ee4d5a654beR104
https://github.com/rui314/chibicc/blob/f5540b578e4bafa57d7ff8d94f4a0a46c95ede12/codegen.c#L104

void gen(Node *node) {
  switch (node->kind) {
  case ND_NUM:
    printf("  push %d\n", node->val);
    return;
  case ND_EXPR_STMT:
    gen(node->lhs);
    printf("  add rsp, 8\n");
    return;
  case ND_VAR:
    gen_addr(node);
    load();
    return;
  case ND_ASSIGN:
    gen_addr(node->lhs);
    gen(node->rhs);
    store();
    return;
  case ND_IF: {
    int seq = labelseq++;
    if (node->els) {
      gen(node->cond);
      printf("  pop rax\n");
      printf("  cmp rax, 0\n");
      printf("  je  .Lelse%d\n", seq);
      gen(node->then);
      printf("  jmp .Lend%d\n", seq);
      printf(".Lelse%d:\n", seq);
      gen(node->els);
      printf(".Lend%d:\n", seq);
    } else {
      gen(node->cond);
      printf("  pop rax\n");
      printf("  cmp rax, 0\n");
      printf("  je  .Lend%d\n", seq);
      gen(node->then);
      printf(".Lend%d:\n", seq);
    }
    return;
  }
  case ND_WHILE: {
    int seq = labelseq++;
    printf(".Lbegin%d:\n", seq);
    gen(node->cond);
    printf("  pop rax\n");
    printf("  cmp rax, 0\n");
    printf("  je  .Lend%d\n", seq);
    gen(node->then);
    printf("  jmp .Lbegin%d\n", seq);
    printf(".Lend%d:\n", seq);
    return;
  }
  case ND_FOR: {
    int seq = labelseq++;
    if (node->init)
      gen(node->init);
    printf(".Lbegin%d:\n", seq);
    if (node->cond) {
      gen(node->cond);
      printf("  pop rax\n");
      printf("  cmp rax, 0\n");
      printf("  je  .Lend%d\n", seq);
    }
    gen(node->then);
    if (node->inc)
      gen(node->inc);
    printf("  jmp .Lbegin%d\n", seq);
    printf(".Lend%d:\n", seq);
    return;
  }
  case ND_BLOCK:
    for (Node *n = node->body; n; n = n->next)
      gen(n);
    return;
  case ND_FUNCALL:
    printf("  call %s\n", node->funcname);
    printf("  push rax\n");
    return;
  case ND_RETURN:
    gen(node->lhs);
    printf("  pop rax\n");
    printf("  jmp .Lreturn\n");
    return;
  }

  gen(node->lhs);
  gen(node->rhs);

  printf("  pop rdi\n");
  printf("  pop rax\n");

  switch (node->kind) {
  case ND_ADD:
    printf("  add rax, rdi\n");
    break;
  case ND_SUB:
    printf("  sub rax, rdi\n");
    break;
  case ND_MUL:
    printf("  imul rax, rdi\n");
    break;
  case ND_DIV:
    printf("  cqo\n");
    printf("  idiv rdi\n");
    break;
  case ND_EQ:
    printf("  cmp rax, rdi\n");
    printf("  sete al\n");
    printf("  movzb rax, al\n");
    break;
  case ND_NE:
    printf("  cmp rax, rdi\n");
    printf("  setne al\n");
    printf("  movzb rax, al\n");
    break;
  case ND_LT:
    printf("  cmp rax, rdi\n");
    printf("  setl al\n");
    printf("  movzb rax, al\n");
    break;
  case ND_LE:
    printf("  cmp rax, rdi\n");
    printf("  setle al\n");
    printf("  movzb rax, al\n");
    break;
  }

  printf("  push rax\n");
}
二項演算以外を行うアセンブリコードを生成する
void gen(Node *node) {
  switch (node->kind) {
  case ND_NUM:
    printf("  push %d\n", node->val);
    return;
  case ND_EXPR_STMT:
    gen(node->lhs);
    printf("  add rsp, 8\n");
    return;
  case ND_VAR:
    gen_addr(node);
    load();
    return;
  case ND_ASSIGN:
    gen_addr(node->lhs);
    gen(node->rhs);
    store();
    return;
  case ND_IF: {
    int seq = labelseq++;
    if (node->els) {
      gen(node->cond);
      printf("  pop rax\n");
      printf("  cmp rax, 0\n");
      printf("  je  .Lelse%d\n", seq);
      gen(node->then);
      printf("  jmp .Lend%d\n", seq);
      printf(".Lelse%d:\n", seq);
      gen(node->els);
      printf(".Lend%d:\n", seq);
    } else {
      gen(node->cond);
      printf("  pop rax\n");
      printf("  cmp rax, 0\n");
      printf("  je  .Lend%d\n", seq);
      gen(node->then);
      printf(".Lend%d:\n", seq);
    }
    return;
  }
  case ND_WHILE: {
    int seq = labelseq++;
    printf(".Lbegin%d:\n", seq);
    gen(node->cond);
    printf("  pop rax\n");
    printf("  cmp rax, 0\n");
    printf("  je  .Lend%d\n", seq);
    gen(node->then);
    printf("  jmp .Lbegin%d\n", seq);
    printf(".Lend%d:\n", seq);
    return;
  }
  case ND_FOR: {
    int seq = labelseq++;
    if (node->init)
      gen(node->init);
    printf(".Lbegin%d:\n", seq);
    if (node->cond) {
      gen(node->cond);
      printf("  pop rax\n");
      printf("  cmp rax, 0\n");
      printf("  je  .Lend%d\n", seq);
    }
    gen(node->then);
    if (node->inc)
      gen(node->inc);
    printf("  jmp .Lbegin%d\n", seq);
    printf(".Lend%d:\n", seq);
    return;
  }
  case ND_BLOCK:
    for (Node *n = node->body; n; n = n->next)
      gen(n);
    return;
  case ND_FUNCALL:
    printf("  call %s\n", node->funcname);
    printf("  push rax\n");
    return;
  case ND_RETURN:
    gen(node->lhs);
    printf("  pop rax\n");
    printf("  jmp .Lreturn\n");
    return;
  }

ノード型がND_FUNCALL(ノードの種類が関数呼び出し)の場合における処理を追加します。
ノード型がND_FUNCALLの場合は、関数呼び出しを行うアセンブリコード”call 関数名”を生成します。
最後に、その関数の戻り値をスタックに退避させるアセンブリコード"push rax"を生成します。

二項演算の対象となる値を得るためのアセンブリコードを生成する(変更なし)
  gen(node->lhs);
  gen(node->rhs);

  printf("  pop rdi\n");
  printf("  pop rax\n");
二項演算を行うアセンブリコードを生成する(変更なし)
  switch (node->kind) {
  case ND_ADD:
    printf("  add rax, rdi\n");
    break;
  case ND_SUB:
    printf("  sub rax, rdi\n");
    break;
  case ND_MUL:
    printf("  imul rax, rdi\n");
    break;
  case ND_DIV:
    printf("  cqo\n");
    printf("  idiv rdi\n");
    break;
  case ND_EQ:
    printf("  cmp rax, rdi\n");
    printf("  sete al\n");
    printf("  movzb rax, al\n");
    break;
  case ND_NE:
    printf("  cmp rax, rdi\n");
    printf("  setne al\n");
    printf("  movzb rax, al\n");
    break;
  case ND_LT:
    printf("  cmp rax, rdi\n");
    printf("  setl al\n");
    printf("  movzb rax, al\n");
    break;
  case ND_LE:
    printf("  cmp rax, rdi\n");
    printf("  setle al\n");
    printf("  movzb rax, al\n");
    break;
  }

  printf("  push rax\n");
}

追加・修正されたコンパイラソースコード(最大引数6つの関数を呼び出せるコンパイラを作成する)

tokenize関数

https://github.com/rui314/chibicc/commit/5dea368205321bdd55722b7be88efd0bd41b7fb4#diff-289479d6df6940b25dd31a6f2da4881331f916ec642bd1ae47d4ff0a365d8e88R145
https://github.com/rui314/chibicc/blob/5dea368205321bdd55722b7be88efd0bd41b7fb4/tokenize.c#L145

Token *tokenize() {
  char *p = user_input;
  Token head;
  head.next = NULL;
  Token *cur = &head;

  while (*p) {
    // Skip whitespace characters.
    if (isspace(*p)) {
      p++;
      continue;
    }

    // Keyword or multi-letter punctuator
    char *kw = starts_with_reserved(p);
    if (kw) {
      int len = strlen(kw);
      cur = new_token(TK_RESERVED, cur, p, len);
      p += len;
      continue;
    }

    // Single-letter punctuator
    if (strchr("+-*/()<>;={},", *p)) {
      cur = new_token(TK_RESERVED, cur, p++, 1);
      continue;
    }

    // Identifier
    if (is_alpha(*p)) {
      char *q = p++;
      while (is_alnum(*p))
        p++;
      cur = new_token(TK_IDENT, cur, q, p - q);
      continue;
    }

    // Integer literal
    if (isdigit(*p)) {
      cur = new_token(TK_NUM, cur, p, 0);
      char *q = p;
      cur->val = strtol(p, &p, 10);
      cur->len = p - q;
      continue;
    }

    error_at(p, "invalid token");
  }

  new_token(TK_EOF, cur, p, 0);
  return head.next;
}
文字列の先頭アドレスを取得する(変更なし)
  char *p = user_input;
トークンからなる連結リストのヘッダーを作成する(変更なし)
  Token head;
  head.next = NULL;
  Token *cur = &head;
空白文字の場合(変更なし)
    // Skip whitespace characters.
    if (isspace(*p)) {
      p++;
      continue;
    }
キーワードの場合(変更なし)
    // Keyword or multi-letter punctuator
    char *kw = starts_with_reserved(p);
    if (kw) {
      int len = strlen(kw);
      cur = new_token(TK_RESERVED, cur, p, len);
      p += len;
      continue;
    }
1文字の記号の場合
    // Single-letter punctuator
    if (strchr("+-*/()<>;={},", *p)) {
      cur = new_token(TK_RESERVED, cur, p++, 1);
      continue;
    }

第一引数の文字列に" , "を追加して、" , "をトークンとして扱えるようにします。

識別子の場合(変更なし)
    // Identifier
    if (is_alpha(*p)) {
      char *q = p++;
      while (is_alnum(*p))
        p++;
      cur = new_token(TK_IDENT, cur, q, p - q);
      continue;
    }
数字の場合(変更なし)
    // Integer literal
    if (isdigit(*p)) {
      cur = new_token(TK_NUM, cur, p, 0);
      char *q = p;
      cur->val = strtol(p, &p, 10);
      cur->len = p - q;
      continue;
    }
その他の場合(変更なし)
    error_at(p, "invalid token");
トークン列の終端を表すトークンを生成する(変更なし)
  new_token(TK_EOF, cur, p, 0);
連結リストの先頭トークンを戻り値としてリターンする(変更なし)
 return head.next;

Node構造体

https://github.com/rui314/chibicc/commit/5dea368205321bdd55722b7be88efd0bd41b7fb4#diff-d06dbb7ef5899cdf50b340464444680b13aded45363e7aba944dc3551fdf6334R99
https://github.com/rui314/chibicc/blob/5dea368205321bdd55722b7be88efd0bd41b7fb4/chibicc.h#L99

// AST node type
typedef struct Node Node;
struct Node {
  NodeKind kind; // Node kind
  Node *next;    // Next node

  Node *lhs;     // Left-hand side
  Node *rhs;     // Right-hand side

  // "if, "while" or "for" statement
  Node *cond;
  Node *then;
  Node *els;
  Node *init;
  Node *inc;

  // Block
  Node *body;

  // Function call
  char *funcname;
  Node *args;

  Var *var;      // Used if kind == ND_VAR
  int val;       // Used if kind == ND_NUM
};

関数呼び出しのパースの際に使用される子ノードargsを追加します。

primary関数

https://github.com/rui314/chibicc/commit/5dea368205321bdd55722b7be88efd0bd41b7fb4#diff-a07721cd062be25900bddb926de15fc103cf32ea2726d1fea286f6548b810c6aR268
https://github.com/rui314/chibicc/blob/5dea368205321bdd55722b7be88efd0bd41b7fb4/parse.c#L268

Node *primary() {
  if (consume("(")) {
    Node *node = expr();
    expect(")");
    return node;
  }

  Token *tok = consume_ident();
  if (tok) {
    if (consume("(")) {
      Node *node = new_node(ND_FUNCALL);
      node->funcname = strndup(tok->str, tok->len);
      node->args = func_args();
      return node;
    }

    Var *var = find_var(tok);
    if (!var)
      var = push_var(strndup(tok->str, tok->len));
    return new_var(var);
  }

  return new_num(expect_number());
}

primary関数は、生成規則 primary = "(" expr ")" | ident func-args? | num に基づいて、抽象構文木のノードを生成します。

"("、expr、")"(変更なし)
  if (consume("(")) {
    Node *node = expr();
    expect(")");
    return node;
  }
ident、func-argsを0回か1回
  Token *tok = consume_ident();
  if (tok) {
    if (consume("(")) {
      Node *node = new_node(ND_FUNCALL);
      node->funcname = strndup(tok->str, tok->len);
      node->args = func_args();
      return node;
    }

    Var *var = find_var(tok);
    if (!var)
      var = push_var(strndup(tok->str, tok->len));
    return new_var(var);
  }

consume_ident関数を呼び出してtokが真となる場合 → 識別子(ローカル変数、または、関数名)を表現するトークンを取得できた場合の処理です。
consume("(")の戻り値がtrueとなる場合 → 識別子の次のトークンが"("の場合 → 識別子が関数名の場合は、new_node関数を呼び出して関数を呼び出しを表すノードを生成し、生成したノードに関数名を記録します。
最後に、func_args関数を呼び出して抽象構文木を生成し、その抽象構文木の(ルートノードからなる)連結リストを子ノードargsとして登録します。

識別子がローカル変数の場合は、find_var関数を呼び出して、Var構造体からなる連結リストからtokが持つ文字列と同じ名前を持つvar構造体変数を取得します。
!varが真となる場合→tokが持つ文字列と同じ名前を持つvar構造体変数を取得できなかった場合は、push_var関数を呼び出して、var構造体を新規に生成しvar構造体の連結リストに追加します。
最後に、new_var関数を呼び出して、ローカル変数に対応するノードを新規に生成します。

num(変更なし)
  return new_num(expect_number());

func_args関数

https://github.com/rui314/chibicc/commit/5dea368205321bdd55722b7be88efd0bd41b7fb4#diff-a07721cd062be25900bddb926de15fc103cf32ea2726d1fea286f6548b810c6aR241
https://github.com/rui314/chibicc/blob/5dea368205321bdd55722b7be88efd0bd41b7fb4/parse.c#L241

Node *func_args() {
  if (consume(")"))
    return NULL;

  Node *head = assign();
  Node *cur = head;
  while (consume(",")) {
    cur->next = assign();
    cur = cur->next;
  }
  expect(")");
  return head;
}

func_args関数は、生成規則 func-args = "(" (assign ("," assign)*)? ")" に基づいて、抽象構文木のノードを生成します。

"("
  if (consume(")"))
    return NULL;

consume(")")の戻り値がtrueとなる場合 → "関数名("の次のトークンが”)”の場合は、引数なしの関数呼び出しになってしまうので、戻り値をNULLにして処理を終了します。

「assign、『"," assign』を0回以上」を0回か1回
  Node *head = assign();
  Node *cur = head;
  while (consume(",")) {
    cur->next = assign();
    cur = cur->next;
  }

assign関数を呼び出して抽象構文木を生成し、その抽象構文木のルートノードをこれから作成する連結リスト(ノード構造体からなる連結リスト)のヘッダーheadとします。

consume(",")の戻り値がfalseになるまで→”,”を表すトークンが出現するまで、while文のループを継続します。
assign関数を呼び出して抽象構文木を生成し、生成された抽象構文木のルートノードのアドレスを戻り値として取得します。
戻り値として取得した抽象構文木のルートノードのアドレスを連結リストの終端要素のnextメンバに格納し、連結リストの終端要素を表すcurを更新します。

")"
  expect(")");
  return head;

トークンは")"であることが期待されているので、expect関数を呼び出してトークン列を1つ読み進めます。
連結リストの先頭ノード(連結リストのヘッダーの次にあるノード)のアドレスを戻り値としてリターンします。

gen関数

https://github.com/rui314/chibicc/commit/5dea368205321bdd55722b7be88efd0bd41b7fb4#diff-629fe11334ae1d560032cdb6cc6f9a4fbb0f5b1365894b6b648d6ee4d5a654beR105
https://github.com/rui314/chibicc/blob/5dea368205321bdd55722b7be88efd0bd41b7fb4/codegen.c#L105

void gen(Node *node) {
  switch (node->kind) {
  case ND_NUM:
    printf("  push %d\n", node->val);
    return;
  case ND_EXPR_STMT:
    gen(node->lhs);
    printf("  add rsp, 8\n");
    return;
  case ND_VAR:
    gen_addr(node);
    load();
    return;
  case ND_ASSIGN:
    gen_addr(node->lhs);
    gen(node->rhs);
    store();
    return;
  case ND_IF: {
    int seq = labelseq++;
    if (node->els) {
      gen(node->cond);
      printf("  pop rax\n");
      printf("  cmp rax, 0\n");
      printf("  je  .Lelse%d\n", seq);
      gen(node->then);
      printf("  jmp .Lend%d\n", seq);
      printf(".Lelse%d:\n", seq);
      gen(node->els);
      printf(".Lend%d:\n", seq);
    } else {
      gen(node->cond);
      printf("  pop rax\n");
      printf("  cmp rax, 0\n");
      printf("  je  .Lend%d\n", seq);
      gen(node->then);
      printf(".Lend%d:\n", seq);
    }
    return;
  }
  case ND_WHILE: {
    int seq = labelseq++;
    printf(".Lbegin%d:\n", seq);
    gen(node->cond);
    printf("  pop rax\n");
    printf("  cmp rax, 0\n");
    printf("  je  .Lend%d\n", seq);
    gen(node->then);
    printf("  jmp .Lbegin%d\n", seq);
    printf(".Lend%d:\n", seq);
    return;
  }
  case ND_FOR: {
    int seq = labelseq++;
    if (node->init)
      gen(node->init);
    printf(".Lbegin%d:\n", seq);
    if (node->cond) {
      gen(node->cond);
      printf("  pop rax\n");
      printf("  cmp rax, 0\n");
      printf("  je  .Lend%d\n", seq);
    }
    gen(node->then);
    if (node->inc)
      gen(node->inc);
    printf("  jmp .Lbegin%d\n", seq);
    printf(".Lend%d:\n", seq);
    return;
  }
  case ND_BLOCK:
    for (Node *n = node->body; n; n = n->next)
      gen(n);
    return;
  case ND_FUNCALL: {
    int nargs = 0;
    for (Node *arg = node->args; arg; arg = arg->next) {
      gen(arg);
      nargs++;
    }

    for (int i = nargs - 1; i >= 0; i--)
      printf("  pop %s\n", argreg[i]);

    printf("  call %s\n", node->funcname);
    printf("  push rax\n");
    return;
  }
  case ND_RETURN:
    gen(node->lhs);
    printf("  pop rax\n");
    printf("  jmp .Lreturn\n");
    return;
  }

  gen(node->lhs);
  gen(node->rhs);

  printf("  pop rdi\n");
  printf("  pop rax\n");

  switch (node->kind) {
  case ND_ADD:
    printf("  add rax, rdi\n");
    break;
  case ND_SUB:
    printf("  sub rax, rdi\n");
    break;
  case ND_MUL:
    printf("  imul rax, rdi\n");
    break;
  case ND_DIV:
    printf("  cqo\n");
    printf("  idiv rdi\n");
    break;
  case ND_EQ:
    printf("  cmp rax, rdi\n");
    printf("  sete al\n");
    printf("  movzb rax, al\n");
    break;
  case ND_NE:
    printf("  cmp rax, rdi\n");
    printf("  setne al\n");
    printf("  movzb rax, al\n");
    break;
  case ND_LT:
    printf("  cmp rax, rdi\n");
    printf("  setl al\n");
    printf("  movzb rax, al\n");
    break;
  case ND_LE:
    printf("  cmp rax, rdi\n");
    printf("  setle al\n");
    printf("  movzb rax, al\n");
    break;
  }

  printf("  push rax\n");
}
二項演算以外を行うアセンブリコードを生成する
  switch (node->kind) {
  case ND_NUM:
    printf("  push %d\n", node->val);
    return;
  case ND_EXPR_STMT:
    gen(node->lhs);
    printf("  add rsp, 8\n");
    return;
  case ND_VAR:
    gen_addr(node);
    load();
    return;
  case ND_ASSIGN:
    gen_addr(node->lhs);
    gen(node->rhs);
    store();
    return;
  case ND_IF: {
    int seq = labelseq++;
    if (node->els) {
      gen(node->cond);
      printf("  pop rax\n");
      printf("  cmp rax, 0\n");
      printf("  je  .Lelse%d\n", seq);
      gen(node->then);
      printf("  jmp .Lend%d\n", seq);
      printf(".Lelse%d:\n", seq);
      gen(node->els);
      printf(".Lend%d:\n", seq);
    } else {
      gen(node->cond);
      printf("  pop rax\n");
      printf("  cmp rax, 0\n");
      printf("  je  .Lend%d\n", seq);
      gen(node->then);
      printf(".Lend%d:\n", seq);
    }
    return;
  }
  case ND_WHILE: {
    int seq = labelseq++;
    printf(".Lbegin%d:\n", seq);
    gen(node->cond);
    printf("  pop rax\n");
    printf("  cmp rax, 0\n");
    printf("  je  .Lend%d\n", seq);
    gen(node->then);
    printf("  jmp .Lbegin%d\n", seq);
    printf(".Lend%d:\n", seq);
    return;
  }
  case ND_FOR: {
    int seq = labelseq++;
    if (node->init)
      gen(node->init);
    printf(".Lbegin%d:\n", seq);
    if (node->cond) {
      gen(node->cond);
      printf("  pop rax\n");
      printf("  cmp rax, 0\n");
      printf("  je  .Lend%d\n", seq);
    }
    gen(node->then);
    if (node->inc)
      gen(node->inc);
    printf("  jmp .Lbegin%d\n", seq);
    printf(".Lend%d:\n", seq);
    return;
  }
  case ND_BLOCK:
    for (Node *n = node->body; n; n = n->next)
      gen(n);
    return;
  case ND_FUNCALL: {
    int nargs = 0;
    for (Node *arg = node->args; arg; arg = arg->next) {
      gen(arg);
      nargs++;
    }

    for (int i = nargs - 1; i >= 0; i--)
      printf("  pop %s\n", argreg[i]);

    printf("  call %s\n", node->funcname);
    printf("  push rax\n");
    return;
  }
  case ND_RETURN:
    gen(node->lhs);
    printf("  pop rax\n");
    printf("  jmp .Lreturn\n");
    return;
  }

ノードの型がND_FUNCALL(ノードの種類が関数呼び出し)の場合における処理を修正します。
抽象構文木のルートノードからなる連結リストargsを用いてgen関数を呼び出し、引数値(ノードargの結果)を得るために必要なアセンブリコードを生成します。
この時、連結リストargsの最後の要素を指定するインデックスを取得するために、変数nargsをインクリメントしておきます。

生成されたアセンブリコードの実行時を考慮すると、関数で使用される引数値はスタックに退避されている状態なので、インデックスnargs、配列argreg、for文を使って、引数の個数分だけアセンブリコード" pop レジスタ名"を生成します。

argregはレジスタ名の要素からなる配列で、何番目の引数値をどのレジスタにセットするかはABIの一部である関数呼び出し規約によって定められています。
https://github.com/rui314/chibicc/commit/5dea368205321bdd55722b7be88efd0bd41b7fb4#diff-629fe11334ae1d560032cdb6cc6f9a4fbb0f5b1365894b6b648d6ee4d5a654beR4
https://github.com/rui314/chibicc/blob/5dea368205321bdd55722b7be88efd0bd41b7fb4/codegen.c#L4

char *argreg[] = {"rdi", "rsi", "rdx", "rcx", "r8", "r9"};

最後に、関数呼び出しを行うアセンブリコード”call 関数名”とその関数の戻り値をスタックに退避させるアセンブリコード"push rax"を生成します。

二項演算の対象となる値を得るためのアセンブリコードを生成する(変更なし)
  gen(node->lhs);
  gen(node->rhs);

  printf("  pop rdi\n");
  printf("  pop rax\n");
二項演算を行うアセンブリコードを生成する(変更なし)
  switch (node->kind) {
  case ND_ADD:
    printf("  add rax, rdi\n");
    break;
  case ND_SUB:
    printf("  sub rax, rdi\n");
    break;
  case ND_MUL:
    printf("  imul rax, rdi\n");
    break;
  case ND_DIV:
    printf("  cqo\n");
    printf("  idiv rdi\n");
    break;
  case ND_EQ:
    printf("  cmp rax, rdi\n");
    printf("  sete al\n");
    printf("  movzb rax, al\n");
    break;
  case ND_NE:
    printf("  cmp rax, rdi\n");
    printf("  setne al\n");
    printf("  movzb rax, al\n");
    break;
  case ND_LT:
    printf("  cmp rax, rdi\n");
    printf("  setl al\n");
    printf("  movzb rax, al\n");
    break;
  case ND_LE:
    printf("  cmp rax, rdi\n");
    printf("  setle al\n");
    printf("  movzb rax, al\n");
    break;
  }

  printf("  push rax\n");

追加・修正されたコンパイラソースコード(関数呼び出しの前にRSPが16の倍数になっているコンパイラを作成する)

gen関数

https://github.com/rui314/chibicc/commit/aedbf56c3af4914e3f183223ff879734683bec73#diff-629fe11334ae1d560032cdb6cc6f9a4fbb0f5b1365894b6b648d6ee4d5a654beR115
https://github.com/rui314/chibicc/blob/aedbf56c3af4914e3f183223ff879734683bec73/codegen.c#L115

void gen(Node *node) {
  switch (node->kind) {
  case ND_NUM:
    printf("  push %d\n", node->val);
    return;
  case ND_EXPR_STMT:
    gen(node->lhs);
    printf("  add rsp, 8\n");
    return;
  case ND_VAR:
    gen_addr(node);
    load();
    return;
  case ND_ASSIGN:
    gen_addr(node->lhs);
    gen(node->rhs);
    store();
    return;
  case ND_IF: {
    int seq = labelseq++;
    if (node->els) {
      gen(node->cond);
      printf("  pop rax\n");
      printf("  cmp rax, 0\n");
      printf("  je  .Lelse%d\n", seq);
      gen(node->then);
      printf("  jmp .Lend%d\n", seq);
      printf(".Lelse%d:\n", seq);
      gen(node->els);
      printf(".Lend%d:\n", seq);
    } else {
      gen(node->cond);
      printf("  pop rax\n");
      printf("  cmp rax, 0\n");
      printf("  je  .Lend%d\n", seq);
      gen(node->then);
      printf(".Lend%d:\n", seq);
    }
    return;
  }
  case ND_WHILE: {
    int seq = labelseq++;
    printf(".Lbegin%d:\n", seq);
    gen(node->cond);
    printf("  pop rax\n");
    printf("  cmp rax, 0\n");
    printf("  je  .Lend%d\n", seq);
    gen(node->then);
    printf("  jmp .Lbegin%d\n", seq);
    printf(".Lend%d:\n", seq);
    return;
  }
  case ND_FOR: {
    int seq = labelseq++;
    if (node->init)
      gen(node->init);
    printf(".Lbegin%d:\n", seq);
    if (node->cond) {
      gen(node->cond);
      printf("  pop rax\n");
      printf("  cmp rax, 0\n");
      printf("  je  .Lend%d\n", seq);
    }
    gen(node->then);
    if (node->inc)
      gen(node->inc);
    printf("  jmp .Lbegin%d\n", seq);
    printf(".Lend%d:\n", seq);
    return;
  }
  case ND_BLOCK:
    for (Node *n = node->body; n; n = n->next)
      gen(n);
    return;
  case ND_FUNCALL: {
    int nargs = 0;
    for (Node *arg = node->args; arg; arg = arg->next) {
      gen(arg);
      nargs++;
    }

    for (int i = nargs - 1; i >= 0; i--)
      printf("  pop %s\n", argreg[i]);

    // We need to align RSP to a 16 byte boundary before
    // calling a function because it is an ABI requirement.
    // RAX is set to 0 for variadic function.
    int seq = labelseq++;
    printf("  mov rax, rsp\n");
    printf("  and rax, 15\n");
    printf("  jnz .Lcall%d\n", seq);
    printf("  mov rax, 0\n");
    printf("  call %s\n", node->funcname);
    printf("  jmp .Lend%d\n", seq);
    printf(".Lcall%d:\n", seq);
    printf("  sub rsp, 8\n");
    printf("  mov rax, 0\n");
    printf("  call %s\n", node->funcname);
    printf("  add rsp, 8\n");
    printf(".Lend%d:\n", seq);
    printf("  push rax\n");
    return;
  }
  case ND_RETURN:
    gen(node->lhs);
    printf("  pop rax\n");
    printf("  jmp .Lreturn\n");
    return;
  }

  gen(node->lhs);
  gen(node->rhs);

  printf("  pop rdi\n");
  printf("  pop rax\n");

  switch (node->kind) {
  case ND_ADD:
    printf("  add rax, rdi\n");
    break;
  case ND_SUB:
    printf("  sub rax, rdi\n");
    break;
  case ND_MUL:
    printf("  imul rax, rdi\n");
    break;
  case ND_DIV:
    printf("  cqo\n");
    printf("  idiv rdi\n");
    break;
  case ND_EQ:
    printf("  cmp rax, rdi\n");
    printf("  sete al\n");
    printf("  movzb rax, al\n");
    break;
  case ND_NE:
    printf("  cmp rax, rdi\n");
    printf("  setne al\n");
    printf("  movzb rax, al\n");
    break;
  case ND_LT:
    printf("  cmp rax, rdi\n");
    printf("  setl al\n");
    printf("  movzb rax, al\n");
    break;
  case ND_LE:
    printf("  cmp rax, rdi\n");
    printf("  setle al\n");
    printf("  movzb rax, al\n");
    break;
  }

  printf("  push rax\n");
}
二項演算以外を行うアセンブリコードを生成する
  switch (node->kind) {
  case ND_NUM:
    printf("  push %d\n", node->val);
    return;
  case ND_EXPR_STMT:
    gen(node->lhs);
    printf("  add rsp, 8\n");
    return;
  case ND_VAR:
    gen_addr(node);
    load();
    return;
  case ND_ASSIGN:
    gen_addr(node->lhs);
    gen(node->rhs);
    store();
    return;
  case ND_IF: {
    int seq = labelseq++;
    if (node->els) {
      gen(node->cond);
      printf("  pop rax\n");
      printf("  cmp rax, 0\n");
      printf("  je  .Lelse%d\n", seq);
      gen(node->then);
      printf("  jmp .Lend%d\n", seq);
      printf(".Lelse%d:\n", seq);
      gen(node->els);
      printf(".Lend%d:\n", seq);
    } else {
      gen(node->cond);
      printf("  pop rax\n");
      printf("  cmp rax, 0\n");
      printf("  je  .Lend%d\n", seq);
      gen(node->then);
      printf(".Lend%d:\n", seq);
    }
    return;
  }
  case ND_WHILE: {
    int seq = labelseq++;
    printf(".Lbegin%d:\n", seq);
    gen(node->cond);
    printf("  pop rax\n");
    printf("  cmp rax, 0\n");
    printf("  je  .Lend%d\n", seq);
    gen(node->then);
    printf("  jmp .Lbegin%d\n", seq);
    printf(".Lend%d:\n", seq);
    return;
  }
  case ND_FOR: {
    int seq = labelseq++;
    if (node->init)
      gen(node->init);
    printf(".Lbegin%d:\n", seq);
    if (node->cond) {
      gen(node->cond);
      printf("  pop rax\n");
      printf("  cmp rax, 0\n");
      printf("  je  .Lend%d\n", seq);
    }
    gen(node->then);
    if (node->inc)
      gen(node->inc);
    printf("  jmp .Lbegin%d\n", seq);
    printf(".Lend%d:\n", seq);
    return;
  }
  case ND_BLOCK:
    for (Node *n = node->body; n; n = n->next)
      gen(n);
    return;
  case ND_FUNCALL: {
    int nargs = 0;
    for (Node *arg = node->args; arg; arg = arg->next) {
      gen(arg);
      nargs++;
    }

    for (int i = nargs - 1; i >= 0; i--)
      printf("  pop %s\n", argreg[i]);

    // We need to align RSP to a 16 byte boundary before
    // calling a function because it is an ABI requirement.
    // RAX is set to 0 for variadic function.
    int seq = labelseq++;
    printf("  mov rax, rsp\n");
    printf("  and rax, 15\n");
    printf("  jnz .Lcall%d\n", seq);
    printf("  mov rax, 0\n");
    printf("  call %s\n", node->funcname);
    printf("  jmp .Lend%d\n", seq);
    printf(".Lcall%d:\n", seq);
    printf("  sub rsp, 8\n");
    printf("  mov rax, 0\n");
    printf("  call %s\n", node->funcname);
    printf("  add rsp, 8\n");
    printf(".Lend%d:\n", seq);
    printf("  push rax\n");
    return;
  }
  case ND_RETURN:
    gen(node->lhs);
    printf("  pop rax\n");
    printf("  jmp .Lreturn\n");
    return;
  }

ノードの型がND_FUNCALL(ノードの種類が関数呼び出し)の場合における処理に、RSPの値を16バイト境界にする処理を追加します。


RSPの値が16バイト境界であるかを判定する処理

①printf(" mov rax, rsp\n");
RSPの値をRAXにコピーするアセンブリコードを生成します。

②printf(" and rax, 15\n");
RAXの値(RSPの値)と0x0000 0000 0000 000Fの各bitをAND演算し、その演算結果をRAXにセットします。
push命令やpop命令はRSPを8バイト単位で変更するので、RSPの下位8bitは 0x0 か 0x8 のどちらかの値となります(RSPの値は 0x???? ???? ???? ???0 か 0x???? ???? ???? ???8 のどちらかの値となります)。
演算結果が0ではない時、RSPの下位8bitは0x8(RSPの値は 0x???? ???? ???? ???8) → RSPの値は16バイト境界ではありません。
演算結果が0の時、RSPの下位8bitは0x0(RSPの値は 0x???? ???? ???? ???0) → RSPの値は16バイト境界となります。

③printf(" jnz .Lcall%d\n", seq);
演算結果が0ではない時(RSPの値が16バイト境界ではない時)にジャンプするアセンブリコードを生成します。


RSPの値が16バイト境界である場合の処理

④printf(" mov rax, 0\n");
関数呼び出しが終わった後、RAXにはその関数の戻り値がセットされるので、RAXを0で初期化するアセンブリコードを生成します。

⑤printf(" call %s\n", node->funcname);
関数呼び出しを行うアセンブリコードを生成します。

⑥printf(" jmp .Lend%d\n", seq);
戻り値をスタックへ退避させる処理へジャンプするアセンブリコードを生成します。


RSPの値が16バイト境界ではない場合の処理

⑦printf(".Lcall%d:\n", seq);
RSPの値が16バイト境界ではない場合に使用されるラベルを生成します。

⑧printf(" sub rsp, 8\n");
RSPから8を減算して(スタックのトップアドレスを下位方向へ8バイト移動させて)、RSPの値を16バイト境界にするアセンブリコードを生成します。

⑨printf(" mov rax, 0\n");
関数呼び出しが終わった後、RAXにはその関数の戻り値がセットされるので、RAXを0で初期化するアセンブリコードを生成します。

⑩printf(" call %s\n", node->funcname);
関数呼び出しを行うアセンブリコードを生成します。

⑪printf(" add rsp, 8\n");
RSPに8を加算して(スタックのトップアドレスを上位方向へ8バイト移動させて)、RSPの値を元の状態に戻すアセンブリコードを生成します。


戻り値をスタックへ退避させる処理
⑫printf(".Lend%d:\n", seq);
戻り値をスタックへ退避させる処理へジャンプする時に使用されるラベルを生成します。

⑬printf(" push rax\n");
戻り値をスタックへ退避させるアセンブリコードを生成します。

二項演算の対象となる値を得るためのアセンブリコードを生成する(変更なし)
  gen(node->lhs);
  gen(node->rhs);

  printf("  pop rdi\n");
  printf("  pop rax\n");
二項演算を行うアセンブリコードを生成する(変更なし)
  switch (node->kind) {
  case ND_ADD:
    printf("  add rax, rdi\n");
    break;
  case ND_SUB:
    printf("  sub rax, rdi\n");
    break;
  case ND_MUL:
    printf("  imul rax, rdi\n");
    break;
  case ND_DIV:
    printf("  cqo\n");
    printf("  idiv rdi\n");
    break;
  case ND_EQ:
    printf("  cmp rax, rdi\n");
    printf("  sete al\n");
    printf("  movzb rax, al\n");
    break;
  case ND_NE:
    printf("  cmp rax, rdi\n");
    printf("  setne al\n");
    printf("  movzb rax, al\n");
    break;
  case ND_LT:
    printf("  cmp rax, rdi\n");
    printf("  setl al\n");
    printf("  movzb rax, al\n");
    break;
  case ND_LE:
    printf("  cmp rax, rdi\n");
    printf("  setle al\n");
    printf("  movzb rax, al\n");
    break;
  }

  printf("  push rax\n");
}

テストコード

https://github.com/rui314/chibicc/blob/aedbf56c3af4914e3f183223ff879734683bec73/test.sh

#!/bin/bash
cat <<EOF | gcc -xc -c -o tmp2.o -
int ret3() { return 3; }
int ret5() { return 5; }
int add(int x, int y) { return x+y; }
int sub(int x, int y) { return x-y; }
int add6(int a, int b, int c, int d, int e, int f) {
  return a+b+c+d+e+f;
}
EOF

assert() {
  expected="$1"
  input="$2"

  ./chibicc "$input" > tmp.s
  gcc -static -o tmp tmp.s tmp2.o
  ./tmp
  actual="$?"

  if [ "$actual" = "$expected" ]; then
    echo "$input => $actual"
  else
    echo "$input => $expected expected, but got $actual"
    exit 1
  fi
}

assert 0 'return 0;'
assert 42 'return 42;'
assert 21 'return 5+20-4;'
assert 41 'return  12 + 34 - 5 ;'
assert 47 'return 5+6*7;'
assert 15 'return 5*(9-6);'
assert 4 'return (3+5)/2;'
assert 10 'return -10+20;'
assert 10 'return - -10;'
assert 10 'return - - +10;'

assert 0 'return 0==1;'
assert 1 'return 42==42;'
assert 1 'return 0!=1;'
assert 0 'return 42!=42;'

assert 1 'return 0<1;'
assert 0 'return 1<1;'
assert 0 'return 2<1;'
assert 1 'return 0<=1;'
assert 1 'return 1<=1;'
assert 0 'return 2<=1;'

assert 1 'return 1>0;'
assert 0 'return 1>1;'
assert 0 'return 1>2;'
assert 1 'return 1>=0;'
assert 1 'return 1>=1;'
assert 0 'return 1>=2;'

assert 3 'a=3; return a;'
assert 8 'a=3; z=5; return a+z;'

assert 1 'return 1; 2; 3;'
assert 2 '1; return 2; 3;'
assert 3 '1; 2; return 3;'

assert 3 'foo=3; return foo;'
assert 8 'foo123=3; bar=5; return foo123+bar;'

assert 3 'if (0) return 2; return 3;'
assert 3 'if (1-1) return 2; return 3;'
assert 2 'if (1) return 2; return 3;'
assert 2 'if (2-1) return 2; return 3;'

assert 3 '{1; {2;} return 3;}'

assert 10 'i=0; while(i<10) i=i+1; return i;'
assert 55 'i=0; j=0; while(i<=10) {j=i+j; i=i+1;} return j;'

assert 55 'i=0; j=0; for (i=0; i<=10; i=i+1) j=i+j; return j;'
assert 3 'for (;;) return 3; return 5;'

assert 3 'return ret3();'
assert 5 'return ret5();'
assert 8 'return add(3, 5);'
assert 2 'return sub(5, 3);'
assert 21 'return add6(1,2,3,4,5,6);'

echo OK

Makefile

https://github.com/rui314/chibicc/blob/aedbf56c3af4914e3f183223ff879734683bec73/Makefile

CFLAGS=-std=c11 -g -static
SRCS=$(wildcard *.c)
OBJS=$(SRCS:.c=.o)

chibicc: $(OBJS)
	$(CC) -o $@ $(OBJS) $(LDFLAGS)

$(OBJS): chibicc.h

test: chibicc
	./test.sh

clean:
	rm -f chibicc *.o *~ tmp*

.PHONY: test clean