From aa2907df297fdfa57e608c110702579303193641 Mon Sep 17 00:00:00 2001
From: Zhou Chengming <zhouchengming1@huawei.com>
Date: Sat, 4 Feb 2017 09:31:25 +0800
Subject: [PATCH] support dup file+symbol

We use kelf_base->symbols to find a unique matching FILE+locals combination
when we call lookup_open(). If we can't find one matching or we find more
than one matching, we error out.

If we find a unique one, we setup table->local_syms in lookup_open(),
so later lookup_local_symbol() could do its lookup based on table->local_syms.

Fixes #604.

Suggested-by: Josh Poimboeuf <jpoimboe@redhat.com>
Signed-off-by: Zhou Chengming <zhouchengming1@huawei.com>
---
 kpatch-build/create-diff-object.c |  73 +++++++++++++++-----
 kpatch-build/lookup.c             | 109 ++++++++++++++++++++++--------
 kpatch-build/lookup.h             |  10 ++-
 3 files changed, 144 insertions(+), 48 deletions(-)

diff --git a/kpatch-build/create-diff-object.c b/kpatch-build/create-diff-object.c
index e28e168..69302f2 100644
--- a/kpatch-build/create-diff-object.c
+++ b/kpatch-build/create-diff-object.c
@@ -1836,6 +1836,43 @@ static void kpatch_process_special_sections(struct kpatch_elf *kelf)
 	}
 }
 
+static struct sym_compare_type *kpatch_elf_locals(struct kpatch_elf *kelf)
+{
+	struct symbol *sym;
+	int i = 0, sym_num = 0;
+	struct sym_compare_type *sym_array;
+
+	list_for_each_entry(sym, &kelf->symbols, list) {
+		if (sym->bind != STB_LOCAL)
+			continue;
+		if (sym->type != STT_FUNC && sym->type != STT_OBJECT)
+			continue;
+
+		sym_num++;
+	}
+
+	if (!sym_num)
+		return NULL;
+
+	sym_array = malloc((sym_num + 1) * sizeof(struct sym_compare_type));
+	if (!sym_array)
+		ERROR("malloc");
+
+	list_for_each_entry(sym, &kelf->symbols, list) {
+		if (sym->bind != STB_LOCAL)
+			continue;
+		if (sym->type != STT_FUNC && sym->type != STT_OBJECT)
+			continue;
+
+		sym_array[i].type = sym->type;
+		sym_array[i++].name = sym->name;
+	}
+	sym_array[i].type = 0;
+	sym_array[i].name = NULL;
+
+	return sym_array;
+}
+
 static void kpatch_create_patches_sections(struct kpatch_elf *kelf,
 					   struct lookup_table *table, char *hint,
 					   char *objname)
@@ -1872,7 +1909,7 @@ static void kpatch_create_patches_sections(struct kpatch_elf *kelf,
 		if (sym->type == STT_FUNC && sym->status == CHANGED) {
 			if (sym->bind == STB_LOCAL) {
 				if (lookup_local_symbol(table, sym->name,
-				                        hint, &result))
+				                        &result))
 					ERROR("lookup_local_symbol %s (%s)",
 					      sym->name, hint);
 			} else {
@@ -2034,11 +2071,8 @@ static void kpatch_create_intermediate_sections(struct kpatch_elf *kelf,
 			if (rela->sym->bind == STB_LOCAL) {
 				/* An unchanged local symbol */
 				ret = lookup_local_symbol(table,
-					rela->sym->name, hint, &result);
-				if (ret == 2)
-					ERROR("lookup_local_symbol: ambiguous %s:%s relocation, needed for %s",
-			               hint, rela->sym->name, sec->base->name);
-				else if (ret)
+					rela->sym->name, &result);
+				if (ret)
 					ERROR("lookup_local_symbol %s:%s needed for %s",
 			               hint, rela->sym->name, sec->base->name);
 
@@ -2417,6 +2451,7 @@ int main(int argc, char *argv[])
 	struct symbol *sym;
 	char *hint = NULL, *objname, *pos;
 	char *mod_symvers_path, *pmod_name;
+	struct sym_compare_type *base_locals;
 
 	arguments.debug = 0;
 	argp_parse (&argp, argc, argv, 0, NULL, &arguments);
@@ -2437,6 +2472,20 @@ int main(int argc, char *argv[])
 	kpatch_check_program_headers(kelf_base->elf);
 	kpatch_check_program_headers(kelf_patched->elf);
 
+	list_for_each_entry(sym, &kelf_base->symbols, list) {
+		if (sym->type == STT_FILE) {
+			hint = sym->name;
+			break;
+		}
+	}
+	if (!hint)
+		ERROR("FILE symbol not found in base. Stripped?\n");
+
+	/* create symbol lookup table */
+	base_locals = kpatch_elf_locals(kelf_base);
+	lookup = lookup_open(arguments.args[2], mod_symvers_path, hint, base_locals);
+	free(base_locals);
+
 	kpatch_mark_grouped_sections(kelf_patched);
 	kpatch_replace_sections_syms(kelf_base);
 	kpatch_replace_sections_syms(kelf_patched);
@@ -2492,18 +2541,6 @@ int main(int argc, char *argv[])
 	 */
 	kpatch_elf_teardown(kelf_patched);
 
-	list_for_each_entry(sym, &kelf_out->symbols, list) {
-		if (sym->type == STT_FILE) {
-			hint = sym->name;
-			break;
-		}
-	}
-	if (!hint)
-		ERROR("FILE symbol not found in output. Stripped?\n");
-
-	/* create symbol lookup table */
-	lookup = lookup_open(arguments.args[2], mod_symvers_path);
-
 	/* extract module name (destructive to arguments.modulefile) */
 	objname = basename(arguments.args[2]);
 	if (!strncmp(objname, "vmlinux-", 8))
diff --git a/kpatch-build/lookup.c b/kpatch-build/lookup.c
index a0a9c93..2be4bd5 100644
--- a/kpatch-build/lookup.c
+++ b/kpatch-build/lookup.c
@@ -55,6 +55,7 @@ struct lookup_table {
 	int obj_nr, exp_nr;
 	struct object_symbol *obj_syms;
 	struct export_symbol *exp_syms;
+	struct object_symbol *local_syms;
 };
 
 #define for_each_obj_symbol(ndx, iter, table) \
@@ -63,6 +64,55 @@ struct lookup_table {
 #define for_each_exp_symbol(ndx, iter, table) \
 	for (ndx = 0, iter = table->exp_syms; ndx < table->exp_nr; ndx++, iter++)
 
+static void find_local_syms(struct lookup_table *table, char *hint,
+			    struct sym_compare_type *locals)
+{
+	struct object_symbol *sym, *file_sym;
+	int i, in_file = 0;
+	struct sym_compare_type *local_index;
+
+	for_each_obj_symbol(i, sym, table) {
+		if (sym->type == STT_FILE) {
+			if (in_file && !local_index->name) {
+				if (table->local_syms)
+					ERROR("find_local_syms for %s: found_dup", hint);
+				table->local_syms = file_sym;
+			}
+
+			if (!strcmp(hint, sym->name)) {
+				in_file = 1;
+				file_sym = sym;
+				local_index = locals;
+			}
+			else
+				in_file = 0;
+
+			continue;
+		}
+
+		if (!in_file)
+			continue;
+		if (sym->bind != STB_LOCAL || (sym->type != STT_FUNC && sym->type != STT_OBJECT))
+			continue;
+
+		if (local_index->name &&
+		    local_index->type == sym->type &&
+		    !strcmp(local_index->name, sym->name))
+			local_index++;
+		else
+			in_file = 0;
+	}
+
+	if (in_file && !local_index->name) {
+		if (table->local_syms)
+			ERROR("find_local_syms for %s: found_dup", hint);
+		table->local_syms = file_sym;
+	}
+
+	if (!table->local_syms)
+		ERROR("find_local_syms for %s: found_none", hint);
+}
+
 static void obj_read(struct lookup_table *table, char *path)
 {
 	Elf *elf;
@@ -203,7 +253,8 @@ static void symvers_read(struct lookup_table *table, char *path)
 	fclose(file);
 }
 
-struct lookup_table *lookup_open(char *obj_path, char *symvers_path)
+struct lookup_table *lookup_open(char *obj_path, char *symvers_path,
+				 char *hint, struct sym_compare_type *locals)
 {
 	struct lookup_table *table;
 
@@ -215,6 +266,10 @@ struct lookup_table *lookup_open(char *obj_path, char *symvers_path)
 	obj_read(table, obj_path);
 	symvers_read(table, symvers_path);
 
+	table->local_syms = NULL;
+	if (locals)
+		find_local_syms(table, hint, locals);
+
 	return table;
 }
 
@@ -225,40 +280,38 @@ void lookup_close(struct lookup_table *table)
 	free(table);
 }
 
-int lookup_local_symbol(struct lookup_table *table, char *name, char *hint,
+int lookup_local_symbol(struct lookup_table *table, char *name,
                         struct lookup_result *result)
 {
-	struct object_symbol *sym, *match = NULL;
-	int i;
+	struct object_symbol *sym;
 	unsigned long pos = 0;
-	char *curfile = NULL;
+	int i, match = 0, in_file = 0;
+
+	if (!table->local_syms)
+		return 1;
 
 	memset(result, 0, sizeof(*result));
 	for_each_obj_symbol(i, sym, table) {
-		if (sym->type == STT_FILE) {
-			if (!strcmp(sym->name, hint)) {
-				curfile = sym->name;
-				continue; /* begin hint file symbols */
-			} else if (curfile)
-				curfile = NULL; /* end hint file symbols */
+		if (sym->skip)
+			continue;
+
+		if (sym->bind == STB_LOCAL && !strcmp(sym->name, name))
+			pos++;
+
+		if (table->local_syms == sym) {
+			in_file = 1;
+			continue;
 		}
-		if (sym->bind == STB_LOCAL) {
-			if (sym->name && !strcmp(sym->name, name)) {
-				/*
-				 * need to count any occurrence of the symbol
-				 * name, unless we've already found a match
-				 */
-				if (!match)
-					pos++;
 
-				if (!curfile)
-					continue;
+		if (!in_file)
+			continue;
 
-				if (match)
-					/* dup file+symbol, unresolvable ambiguity */
-					return 2;
-				match = sym;
-			}
+		if (sym->type == STT_FILE)
+			break;
+
+		if (sym->bind == STB_LOCAL && !strcmp(sym->name, name)) {
+			match = 1;
+			break;
 		}
 	}
 
@@ -266,8 +319,8 @@ int lookup_local_symbol(struct lookup_table *table, char *name, char *hint,
 		return 1;
 
 	result->pos = pos;
-	result->value = match->value;
-	result->size = match->size;
+	result->value = sym->value;
+	result->size = sym->size;
 	return 0;
 }
 
diff --git a/kpatch-build/lookup.h b/kpatch-build/lookup.h
index 821103d..327a9d5 100644
--- a/kpatch-build/lookup.h
+++ b/kpatch-build/lookup.h
@@ -9,9 +9,15 @@ struct lookup_result {
 	unsigned long pos;
 };
 
-struct lookup_table *lookup_open(char *obj_path, char *symvers_path);
+struct sym_compare_type {
+	char *name;
+	int type;
+};
+
+struct lookup_table *lookup_open(char *obj_path, char *symvers_path,
+				 char *hint, struct sym_compare_type *locals);
 void lookup_close(struct lookup_table *table);
-int lookup_local_symbol(struct lookup_table *table, char *name, char *hint,
+int lookup_local_symbol(struct lookup_table *table, char *name,
                         struct lookup_result *result);
 int lookup_global_symbol(struct lookup_table *table, char *name,
                          struct lookup_result *result);